import { gridColumnPositionsSelector, gridColumnsTotalWidthSelector, useGridApiContext } from '@mui/x-data-grid';
import React from 'react';
import midTheme from '../../styles/midTheme';
import { SkeletonCell, SkeletonGridWrapper, SkeletonWrapper } from './TableLoadingState.styles';

/**
 * Skeleton placeholder grid shown while a MUI X Data Grid is loading.
 *
 * Renders one `SkeletonCell` per visible column for as many rows as fit in the
 * target height. Each row also gets one trailing, empty `SkeletonCell`. The
 * component must be rendered inside a DataGrid because it reads grid state
 * through `useGridApiContext`.
 *
 * NOTE(review): despite the `use` prefix this is a component (it returns JSX),
 * not a hook. The name is kept so existing references keep working.
 *
 * @param height Optional height override. When truthy, rows are generated to
 *   fill twice this height. Otherwise the grid viewport's inner height is used,
 *   or 0 when the grid dimensions are not available yet.
 */
const useTableLoadingState: React.FC<{ height?: number }> = ({ height }) => {
  const apiRef = useGridApiContext();

  const dimensions = apiRef.current?.getRootDimensions();
  const viewportHeight = height ? height * 2 : (dimensions?.viewportInnerSize.height ?? 0);
  const rowHeight = midTheme.var.muiTableRowHeight;
  // A zero row height (or an infinite viewport height) makes this ratio Infinity,
  // and the row loop below would then never terminate. Clamp anything
  // non-finite or non-positive to "no rows".
  const rawRowsCount = Math.ceil(viewportHeight / rowHeight);
  const skeletonRowsCount = Number.isFinite(rawRowsCount) && rawRowsCount > 0 ? rawRowsCount : 0;

  const totalWidth = gridColumnsTotalWidthSelector(apiRef);
  const positions = gridColumnPositionsSelector(apiRef);
  // NOTE(review): if `positions` are column start offsets, every one of them is
  // <= totalWidth, so this counts all visible columns. Confirm this is intended.
  const inViewportCount = React.useMemo(
    () => positions.filter((value) => value <= totalWidth).length,
    [totalWidth, positions],
  );

  // Slice inside useMemo so `columns` keeps the same reference across renders.
  // Otherwise the `children` memo below would be invalidated on every render.
  // If the grid hands back a fresh array each time, this simply recomputes,
  // which is no worse than before.
  const visibleColumns = apiRef.current.getVisibleColumns();
  const columns = React.useMemo(
    () => visibleColumns.slice(0, inViewportCount),
    [visibleColumns, inViewportCount],
  );

  const children = React.useMemo(() => {
    const array: React.ReactNode[] = [];

    for (let i = 0; i < skeletonRowsCount; i++) {
      for (const column of columns) {
        array.push(
          <SkeletonCell key={`column-${i}-${column.field}`} column={column}>
            <SkeletonWrapper />
          </SkeletonCell>,
        );
      }
      // One trailing, content-less cell per row.
      array.push(<SkeletonCell key={`row-${i}`} />);
    }
    return array;
  }, [skeletonRowsCount, columns]);

  return (
    <SkeletonGridWrapper columns={columns} rowHeight={rowHeight}>
      {children}
    </SkeletonGridWrapper>
  );
};

export default useTableLoadingState;
