import React, { useLayoutEffect, useRef, useState } from "react";
import {
  Cell,
  ColumnDef,
  flexRender,
  getCoreRowModel,
  Row,
  useReactTable,
} from "@tanstack/react-table";
import {
  useVirtualizer,
  VirtualItem,
  Virtualizer,
} from "@tanstack/react-virtual";
import classNames from "classnames";

import LoadingLogo from "./LoadingLogo";
import * as styles from "../styles/table.module.scss";

// Row height, in pixels, that the row virtualizer is given as its size
// estimate (see `estimateSize` in `TableBody`).
const baseRowHeight = 100;

/**
 * Props accepted by the `Table` component.
 *
 * `RowData` is the shape of a single element of `data`.
 */
interface TableProps<RowData extends {}> {
  /**
   * Extra class name for the outer container `<div>`. That element is also
   * the one handed to the row virtualizer as its scroll element.
   */
  containerClassName?: string;
  /** Extra class name for the `<table>` element. */
  tableClassName?: string;
  /** Returns a class name for a body row, given that row's original data. */
  getRowClassName?: (row: RowData) => string;
  /** Returns a class name for a body cell, given its TanStack `Cell`. */
  getCellClassName?: (cell: Cell<RowData, unknown>) => string;
  /**
   * When truthy, the table receives the loading style and the body is
   * replaced by a loading logo (unless `error` is also truthy).
   */
  loading?: boolean;
  /**
   * When truthy, the body is replaced by an error message. Checked before
   * `loading` when choosing what to show in place of the body.
   */
  error?: boolean;
  /** Row objects passed to `useReactTable` as `data`. */
  data: RowData[];
  /** Column definitions passed to `useReactTable` as `columns`. */
  columns: ColumnDef<RowData>[];
  /**
   * Zero-based index of the column that should grow in the display flexbox:
   * its header cell and body cells are rendered with `flexGrow: 1`.
   */
  growColumnIndex?: number;
}

/**
 * Generic data table built on TanStack Table with a virtualized body.
 *
 * The outer container's `maxHeight` is derived from the viewport: it is
 * measured once on mount and again on every window `resize` event. While
 * `loading` or `error` is truthy the body is replaced by a single placeholder
 * row; otherwise rows are rendered through `TableBody`.
 *
 * @param containerClassName Extra class name for the outer container div.
 * @param tableClassName Extra class name for the `<table>` element.
 * @param getRowClassName Optional per-row class name callback.
 * @param getCellClassName Optional per-cell class name callback.
 * @param loading Shows the loading placeholder and loading style when truthy.
 * @param error Shows the error placeholder when truthy (wins over `loading`).
 * @param data Row objects to display.
 * @param columns TanStack column definitions.
 * @param growColumnIndex Index of the column rendered with `flexGrow: 1`.
 */
const Table = <RowData extends {}>({
  containerClassName,
  tableClassName,
  getRowClassName,
  getCellClassName,
  loading,
  error,
  data,
  columns,
  growColumnIndex,
}: TableProps<RowData>) => {
  const containerRef = useRef<HTMLDivElement>(null);

  const [maxHeight, setMaxHeight] = useState<number | undefined>(undefined);

  const table = useReactTable({
    data,
    columns,
    getCoreRowModel: getCoreRowModel(),
    columnResizeMode: "onChange",
  });

  useLayoutEffect(() => {
    const container = containerRef.current;

    if (!container) return;

    // Pixels subtracted from the space between the container's top edge and
    // the bottom of the viewport. NOTE(review): presumably reserved for page
    // content rendered below the table — confirm against the page layout.
    const reservedBottomSpace = 250;

    const handleResize = () => {
      const offsetTop = container.getBoundingClientRect().top;
      const viewportHeight = window.innerHeight;
      setMaxHeight(viewportHeight - offsetTop - reservedBottomSpace);
    };

    handleResize();
    window.addEventListener("resize", handleResize);

    return () => {
      window.removeEventListener("resize", handleResize);
    };
    // Run once on mount. The container div is rendered unconditionally, so the
    // ref is already populated when this layout effect fires.
    // `containerRef.current` is intentionally not a dependency: mutating a ref
    // never triggers a render, and listing it only makes the effect tear down
    // and re-subscribe when the value flips from null to the element between
    // the first and second render.
  }, []);

  return (
    <div
      ref={containerRef}
      className={classNames(containerClassName, styles.tableContainer)}
      style={{ maxHeight }}
    >
      <table
        className={classNames(tableClassName, styles.table, {
          [styles.loadingTable]: loading,
        })}
      >
        <thead>
          {table.getHeaderGroups().map((headerGroup) => (
            <tr key={headerGroup.id}>
              {headerGroup.headers.map((header, headerIndex) => (
                <th
                  key={header.id}
                  colSpan={header.colSpan}
                  style={{
                    flexGrow: headerIndex === growColumnIndex ? 1 : undefined,
                    width: header.getSize(),
                  }}
                >
                  {header.isPlaceholder
                    ? null
                    : flexRender(
                        header.column.columnDef.header,
                        header.getContext()
                      )}
                </th>
              ))}
            </tr>
          ))}
        </thead>
        {loading || error ? (
          <tbody>
            <tr>
              {/* Fixed large colSpan; presumably chosen so the placeholder
                  cell spans every column regardless of how many there are. */}
              <td colSpan={100}>
                {error ? (
                  <p className={styles.errorText}>Error loading data</p>
                ) : (
                  <LoadingLogo className={styles.loadingIcon} />
                )}
              </td>
            </tr>
          </tbody>
        ) : (
          <TableBody
            getRowClassName={getRowClassName}
            getCellClassName={getCellClassName}
            table={table}
            tableContainerRef={containerRef}
            growColumnIndex={growColumnIndex}
          />
        )}
      </table>
    </div>
  );
};

/** Props accepted by the `TableBody` component. */
interface TableBodyProps<RowData extends {}> {
  /** Forwarded to each `TableRow`; returns a class name for a row's data. */
  getRowClassName?: (row: RowData) => string;
  /** Forwarded to each `TableRow`; returns a class name for a cell. */
  getCellClassName?: (cell: Cell<RowData, unknown>) => string;
  /** Table instance from `useReactTable`; its row model supplies the rows. */
  table: ReturnType<typeof useReactTable<RowData>>;
  /** Ref to the element the row virtualizer uses as its scroll element. */
  tableContainerRef: React.RefObject<HTMLDivElement>;
  /** Forwarded to each `TableRow`; index of the cell given `flexGrow: 1`. */
  growColumnIndex?: number;
}

/**
 * Virtualized `<tbody>`: renders a `TableRow` only for the items the row
 * virtualizer currently reports, and sizes the `<tbody>` to the virtualizer's
 * total size so the scroll container keeps the full scroll extent.
 *
 * @param getRowClassName Optional per-row class name callback, forwarded.
 * @param getCellClassName Optional per-cell class name callback, forwarded.
 * @param table Table instance whose row model supplies the rows.
 * @param tableContainerRef Ref to the scroll element used by the virtualizer.
 * @param growColumnIndex Index of the cell rendered with `flexGrow: 1`.
 */
const TableBody = <RowData extends {}>({
  getRowClassName,
  getCellClassName,
  table,
  tableContainerRef,
  growColumnIndex,
}: TableBodyProps<RowData>) => {
  const { rows } = table.getRowModel();

  const rowVirtualizer = useVirtualizer<HTMLDivElement, HTMLTableRowElement>({
    count: rows.length,
    estimateSize: () => baseRowHeight,
    getScrollElement: () => tableContainerRef.current,
    // A custom height measurement is supplied only in a browser whose user
    // agent does not mention Firefox; otherwise the option is left undefined.
    // NOTE(review): this mirrors the TanStack virtualized-rows example,
    // presumably because Firefox mis-measures table row heights — confirm.
    measureElement:
      typeof window !== "undefined" &&
      !navigator.userAgent.includes("Firefox")
        ? (element) => element?.getBoundingClientRect().height
        : undefined,
    overscan: 10,
  });

  return (
    <tbody
      style={{
        height: `${rowVirtualizer.getTotalSize()}px`,
      }}
    >
      {rowVirtualizer.getVirtualItems().map((virtualRow) => {
        const row = rows[virtualRow.index];

        // Skip a virtual item that has no matching row instead of throwing
        // on `row.id` below.
        if (!row) return null;

        return (
          <TableRow
            key={row.id}
            row={row}
            virtualRow={virtualRow}
            rowVirtualizer={rowVirtualizer}
            getRowClassName={getRowClassName}
            getCellClassName={getCellClassName}
            growColumnIndex={growColumnIndex}
          />
        );
      })}
    </tbody>
  );
};

/** Props accepted by the `TableRow` component. */
interface TableRowProps<RowData extends {}> {
  /** TanStack row whose visible cells are rendered. */
  row: Row<RowData>;
  /**
   * Virtual item for this row: `index` feeds `data-index` and the even-row
   * style, `start` is the vertical offset applied via `translateY`.
   */
  virtualRow: VirtualItem;
  /** Virtualizer whose `measureElement` receives the rendered `<tr>` node. */
  rowVirtualizer: Virtualizer<HTMLDivElement, HTMLTableRowElement>;
  /** Returns a class name for the row, given the row's original data. */
  getRowClassName?: (row: RowData) => string;
  /** Returns a class name for a cell, given its TanStack `Cell`. */
  getCellClassName?: (cell: Cell<RowData, unknown>) => string;
  /** Zero-based index of the cell rendered with `flexGrow: 1`. */
  growColumnIndex?: number;
}

/**
 * One virtualized table row.
 *
 * The `<tr>` is shifted to its virtual offset with a `translateY` transform,
 * tagged with `data-index`, and handed to the virtualizer's `measureElement`
 * through its ref. Rows with an even virtual index receive the even-row
 * style in addition to whatever `getRowClassName` returns.
 */
const TableRow = <RowData extends {}>({
  row,
  virtualRow,
  rowVirtualizer,
  getRowClassName,
  getCellClassName,
  growColumnIndex,
}: TableRowProps<RowData>) => {
  const { index: rowIndex, start: offsetTop } = virtualRow;

  const rowClassName = classNames(getRowClassName?.(row.original), {
    [styles.evenRow]: rowIndex % 2 === 0,
  });

  // Only the cell at `growColumnIndex` gets a flex-grow value; every cell
  // takes its width from its column's current size.
  const renderedCells = row.getVisibleCells().map((cell, position) => (
    <td
      key={cell.id}
      className={getCellClassName?.(cell)}
      style={{
        flexGrow: position === growColumnIndex ? 1 : undefined,
        width: cell.column.getSize(),
      }}
    >
      {flexRender(cell.column.columnDef.cell, cell.getContext())}
    </td>
  ));

  return (
    <tr
      key={row.id}
      data-index={rowIndex}
      ref={(node) => rowVirtualizer.measureElement(node)}
      className={rowClassName}
      style={{ transform: `translateY(${offsetTop}px)` }}
    >
      {renderedCells}
    </tr>
  );
};

export default Table;
