import {
  ColumnDef,
  flexRender,
  getCoreRowModel,
  getPaginationRowModel,
  getSortedRowModel,
  getFilteredRowModel,
  useReactTable,
  SortingState,
  ColumnFiltersState,
  VisibilityState
} from '@tanstack/react-table';
import { Fragment, useState } from 'react';
import { FilledButton } from './Buttons';
import { cn } from 'shared/lib';

/**
 * Table-specific options read from a column definition's `meta` field.
 */
interface ColumnMeta {
  /**
   * When truthy, the column starts hidden and its value is shown in the
   * per-row expandable detail section instead of the main grid.
   */
  isCollapsible?: boolean;
  /**
   * Per-column filter switch.
   * NOTE(review): not read by any code visible in this file — confirm whether
   * another consumer relies on it.
   */
  enableColumnFilter?: boolean;
}

/**
 * Props accepted by `DataTable`.
 *
 * @typeParam TData - shape of a single row object.
 */
interface DataTableProps<TData> {
  /** Column definitions handed to the table instance. */
  columns: ColumnDef<TData, unknown>[];
  /** Row objects to display. */
  data: TData[];
  /**
   * Optional callback returning extra class names for a row's `<tr>`; it
   * receives the original row object.
   */
  getRowClassName?: (row: TData) => string;
  /**
   * When true, the pagination row model is enabled and Back/Next controls are
   * rendered below the table.
   */
  enablePagination?: boolean;
}

/**
 * Resolves the id under which a column definition is tracked in table state.
 *
 * NOTE(review): intended to mirror TanStack Table's own fallback chain
 * (explicit `id`, then `accessorKey` with dots replaced by underscores, then
 * a string `header`) — confirm against the installed library version.
 *
 * @param column - the column definition to inspect.
 * @returns the resolved id, or `undefined` when none can be derived.
 */
function resolveColumnId<TData>(
  column: ColumnDef<TData, unknown>
): string | undefined {
  if (typeof column.id === 'string') {
    return column.id;
  }
  if ('accessorKey' in column && typeof column.accessorKey === 'string') {
    return column.accessorKey.replace(/\./g, '_');
  }
  return typeof column.header === 'string' ? column.header : undefined;
}

/**
 * Generic table with click-to-sort headers, optional pagination, and
 * per-row expandable details.
 *
 * Columns whose `meta.isCollapsible` is truthy start hidden. While at least
 * one such column is hidden, every row gets a trailing toggle button that
 * reveals the hidden columns' values in a detail row underneath it.
 *
 * @param data - row objects to render.
 * @param columns - column definitions for the table instance.
 * @param getRowClassName - optional callback returning extra class names for
 *   a row's `<tr>`.
 * @param enablePagination - when true, paginates rows and renders Back/Next
 *   controls (defaults to `false`).
 */
export function DataTable<TData>({
  data,
  columns,
  getRowClassName,
  enablePagination = false
}: DataTableProps<TData>) {
  const [sortingState, setSortingState] = useState<SortingState>([]);
  const [columnFilters, setColumnFilters] = useState<ColumnFiltersState>([]);
  const [expandedRows, setExpandedRows] = useState<Set<string>>(
    () => new Set()
  );

  // Ids of collapsible columns. Definitions without a resolvable id are
  // skipped so that no bogus "undefined" key ends up in the visibility state.
  const collapsibleColumns = columns
    .filter((col) => (col.meta as ColumnMeta | undefined)?.isCollapsible)
    .map((col) => resolveColumnId(col))
    .filter((id): id is string => id !== undefined);

  // Collapsible columns are hidden initially; the initializer runs only once.
  const [columnVisibility, setColumnVisibility] = useState<VisibilityState>(
    () => Object.fromEntries(collapsibleColumns.map((col) => [col, false]))
  );

  const table = useReactTable({
    data,
    columns,
    getCoreRowModel: getCoreRowModel(),
    getSortedRowModel: getSortedRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    onSortingChange: setSortingState,
    onColumnFiltersChange: setColumnFilters,
    onColumnVisibilityChange: setColumnVisibility,
    state: { sorting: sortingState, columnFilters, columnVisibility },
    ...(enablePagination
      ? { getPaginationRowModel: getPaginationRowModel() }
      : {})
  });

  const hiddenColumns = collapsibleColumns.filter(
    (col) => columnVisibility[col] === false
  );
  const hasHiddenColumns = hiddenColumns.length > 0;

  // Functional update so rapid successive clicks never work from a stale set.
  const toggleRow = (rowId: string) => {
    setExpandedRows((previous) => {
      const next = new Set(previous);
      // `delete` reports whether the id was present; add it when it was not.
      if (!next.delete(rowId)) {
        next.add(rowId);
      }
      return next;
    });
  };

  return (
    <div className="w-full overflow-x-auto">
      <table className="min-w-full border border-gray-300">
        <thead className="bg-gray-100">
          {table.getHeaderGroups().map((headerGroup) => (
            <tr key={headerGroup.id} className="border-b">
              {headerGroup.headers.map((header) => {
                const canSort = header.column.getCanSort();
                const sortDirection = header.column.getIsSorted();
                return (
                  <th
                    key={header.id}
                    colSpan={header.colSpan}
                    className={cn(
                      'px-4 py-2 text-left',
                      canSort ? 'cursor-pointer select-none' : ''
                    )}
                    onClick={
                      canSort
                        ? header.column.getToggleSortingHandler()
                        : undefined
                    }
                    aria-sort={
                      sortDirection === 'asc'
                        ? 'ascending'
                        : sortDirection === 'desc'
                        ? 'descending'
                        : undefined
                    }
                  >
                    {/* Placeholder headers only pad grouped column layouts. */}
                    {header.isPlaceholder
                      ? null
                      : flexRender(
                          header.column.columnDef.header,
                          header.getContext()
                        )}
                    {sortDirection === 'asc'
                      ? ' ▲'
                      : sortDirection === 'desc'
                      ? ' ▼'
                      : ''}
                  </th>
                );
              })}
              {hasHiddenColumns && <th className="px-4 py-2"></th>}
            </tr>
          ))}
        </thead>
        <tbody>
          {table.getRowModel().rows.map((row) => {
            const isExpanded = hasHiddenColumns && expandedRows.has(row.id);
            return (
              <Fragment key={row.id}>
                <tr
                  className={cn(
                    'border-b',
                    getRowClassName?.(row.original) ?? ''
                  )}
                >
                  {row.getVisibleCells().map((cell) => (
                    <td key={cell.id} className="px-4 py-2">
                      {flexRender(
                        cell.column.columnDef.cell,
                        cell.getContext()
                      )}
                    </td>
                  ))}
                  {hasHiddenColumns && (
                    <td className="px-4 py-2 text-center">
                      {/* type="button" keeps the toggle from submitting an
                          enclosing form. */}
                      <button
                        type="button"
                        className="p-1 bg-gray-200 rounded"
                        aria-expanded={isExpanded}
                        aria-label="Toggle row details"
                        onClick={() => toggleRow(row.id)}
                      >
                        ...
                      </button>
                    </td>
                  )}
                </tr>
                {isExpanded && (
                  <tr className="border-b bg-gray-50">
                    {/* Span the visible columns plus the toggle column. */}
                    <td
                      colSpan={table.getVisibleLeafColumns().length + 1}
                      className="px-4 py-2"
                    >
                      {/* NOTE(review): the `ag-` prefixed classes differ from
                          the unprefixed utilities used elsewhere in this
                          component — confirm the prefix is intended. */}
                      <div className="flex gap-10 ag-items-end">
                        {hiddenColumns.flatMap((columnId) => {
                          const cell = row
                            .getAllCells()
                            .find((c) => c.column.id === columnId);
                          if (!cell) {
                            return [];
                          }
                          const { header } = cell.column.columnDef;
                          return [
                            <div key={cell.column.id}>
                              <div className="ag-font-medium">
                                {/* Non-string headers cannot be rendered
                                    without a header context; fall back to the
                                    column id. */}
                                {typeof header === 'string'
                                  ? header
                                  : cell.column.id}
                              </div>
                              <div>
                                {flexRender(
                                  cell.column.columnDef.cell,
                                  cell.getContext()
                                )}
                              </div>
                            </div>
                          ];
                        })}
                      </div>
                    </td>
                  </tr>
                )}
              </Fragment>
            );
          })}
        </tbody>
      </table>
      {enablePagination && (
        <div className="flex justify-between items-center mt-4">
          <FilledButton
            onClick={() => table.previousPage()}
            disabled={!table.getCanPreviousPage()}
          >
            Back
          </FilledButton>
          <span>
            {/* Clamp to 1 so an empty table does not read "Page 1 of 0". */}
            Page {table.getState().pagination.pageIndex + 1} of{' '}
            {Math.max(table.getPageCount(), 1)}
          </span>
          <FilledButton
            onClick={() => table.nextPage()}
            disabled={!table.getCanNextPage()}
          >
            Next
          </FilledButton>
        </div>
      )}
    </div>
  );
}
