import {
    GRID_CHECKBOX_SELECTION_COL_DEF,
    GRID_DETAIL_PANEL_TOGGLE_FIELD,
    gridExpandedSortedRowIdsSelector,
    GridRowId,
    GridRowSelectionModel,
} from '@mui/x-data-grid-premium';
import { GridApiPremium } from '@mui/x-data-grid-premium/models/gridApiPremium';
import { MutableRefObject, useCallback, useEffect, useRef } from 'react';
import { DataGridColDef, DataGridRowId, DataGridRowSelectionType, DataGridRowType } from './baseDataGridTypes';

/**
 * Adds click-to-select behaviour to a MUI premium data grid and de-duplicates
 * selection-change notifications.
 *
 * The row-click handling is adapted from MUI's own `useGridRowSelection`. Clicks on these never change the selection:
 * - the checkbox column
 * - the detail-panel toggle
 * - columns whose `_columnType` is `'action'`
 * - pinned rows
 *
 * In `'multiple'` mode, Shift+click selects/deselects a range from the last toggled row.
 * Ctrl+click (Cmd+click on macOS) toggles one row without clearing the rest.
 * Any other click toggles the row and clears the rest of the selection.
 *
 * @param gridApiRef - ref to the grid API; events are subscribed on `gridApiRef.current`.
 * @param rowSelection - selection mode; `'none'` disables click selection.
 * @param rowSelectionState - controlled selection (if any); used as the baseline for change detection.
 * @param rowSelectionOnClick - when `false`, clicking a row does not change the selection.
 * @param onRowSelectionChange - called with a copy of the new selection ids, only when the set of ids actually changed.
 * @returns `handleRowSelectionModelChange`, presumably wired to the grid's `onRowSelectionModelChange` — confirm at call site.
 */
export const useBaseDataGridRowSelection = <T extends DataGridRowType>(
    gridApiRef: MutableRefObject<GridApiPremium>,
    rowSelection: DataGridRowSelectionType,
    rowSelectionState: string[] | undefined,
    rowSelectionOnClick: boolean,
    onRowSelectionChange: ((rowIds: DataGridRowId[]) => void) | undefined
) => {
    // Anchor for Shift+click range selection. Kept as the grid's own id (not stringified) so it
    // compares equal (`===`) to the ids returned by `gridExpandedSortedRowIdsSelector`, including
    // when rows use numeric ids.
    const lastSelectedRowId = useRef<GridRowId | undefined>(undefined);

    // Last selection reported through `onRowSelectionChange`; used to suppress duplicate notifications.
    const selectedRowIdsRef = useRef<DataGridRowId[]>(rowSelectionState ?? []);

    // Keep the baseline in sync with the controlled selection. Without this, a selection changed
    // from outside (e.g. cleared by the parent) leaves a stale baseline, and re-selecting the rows
    // that were last reported would be treated as "no change" and never reported.
    useEffect(() => {
        if (rowSelectionState !== undefined) {
            selectedRowIdsRef.current = [...rowSelectionState];
        }
    }, [rowSelectionState]);

    // handles the row selection on click
    // taken from https://github.com/mui/mui-x/blob/bdd82aeac9faf01ad68ed2e8f3258c39fc29ee49/packages/x-data-grid/src/hooks/features/rowSelection/useGridRowSelection.ts#L549
    // and adapted to our needs
    useEffect(() => {
        const unsubscribeRowClick = gridApiRef.current.subscribeEvent('rowClick', ({ id }, event) => {
            if (!rowSelectionOnClick || rowSelection === 'none') {
                return;
            }

            const field = (event.target as HTMLDivElement).closest(`.MuiDataGrid-cell`)?.getAttribute('data-field');

            if (field === GRID_CHECKBOX_SELECTION_COL_DEF.field) {
                return; // click on checkbox should not trigger row selection
            }

            if (field === GRID_DETAIL_PANEL_TOGGLE_FIELD) {
                return; // click to open the detail panel should not select the row
            }

            if (field) {
                const column = gridApiRef.current.getColumn(field) as unknown as DataGridColDef<T> | undefined;

                if (column?._columnType === 'action') {
                    return; // click on an action column should not select the row
                }
            }

            const rowNode = gridApiRef.current.getRowNode(id);
            if (!rowNode || rowNode.type === 'pinnedRow') {
                return; // unknown rows and pinned rows are not selectable by click
            }

            const isMultiple = rowSelection === 'multiple';
            const isSelected = gridApiRef.current.isRowSelected(id);

            if (isMultiple && event.shiftKey) {
                // select (or deselect) the range between the anchor and the clicked row
                let endId = id;
                const startId = lastSelectedRowId.current ?? id;
                if (isSelected) {
                    // Deselecting: end the range on the neighbour of the clicked row (towards the
                    // anchor), so the clicked row itself is left out of the range.
                    const visibleRowIds = gridExpandedSortedRowIdsSelector(gridApiRef);
                    const startIndex = visibleRowIds.findIndex((rowId) => rowId === startId);
                    const endIndex = visibleRowIds.findIndex((rowId) => rowId === endId);
                    if (startIndex === endIndex) {
                        return;
                    }
                    if (startIndex > endIndex) {
                        endId = visibleRowIds[endIndex + 1];
                    } else {
                        endId = visibleRowIds[endIndex - 1];
                    }
                }

                gridApiRef.current.selectRowRange({ startId, endId }, !isSelected);
            } else if (isMultiple && (event.ctrlKey || event.metaKey)) {
                // Ctrl/Cmd+click: toggle this row, keep the rest of the selection
                gridApiRef.current.selectRow(id, !isSelected, false);
            } else {
                // plain click: toggle this row and reset the rest of the selection
                gridApiRef.current.selectRow(id, !isSelected, true);
            }

            // remember the latest row id (for range selection)
            lastSelectedRowId.current = id;
        });

        const unsubscribeRowSelectionChange = gridApiRef.current.subscribeEvent('rowSelectionChange', (selectionModel) => {
            const lastId = selectionModel[selectionModel.length - 1];
            // compare against undefined rather than truthiness: `0` is a valid row id
            if (lastId === undefined) return;
            lastSelectedRowId.current = lastId; // remember the latest row id (for range selection)
        });

        return () => {
            unsubscribeRowClick();
            unsubscribeRowSelectionChange();
        };
    }, [gridApiRef, rowSelection, rowSelectionOnClick]);

    /** Forward the selection to `onRowSelectionChange` only when the set of selected ids really changes. */
    const handleRowSelectionModelChange = useCallback(
        (rowSelectionModel: GridRowSelectionModel) => {
            const previousIds = selectedRowIdsRef.current;
            if (previousIds.length === rowSelectionModel.length) {
                // Set lookup keeps the comparison linear for large selections.
                const previousSet = new Set<DataGridRowId>(previousIds);
                if (rowSelectionModel.every((id) => previousSet.has(id as DataGridRowId))) {
                    return; // same ids as last reported; nothing to notify
                }
            }

            selectedRowIdsRef.current = [...rowSelectionModel] as DataGridRowId[];
            onRowSelectionChange?.([...selectedRowIdsRef.current]);
        },
        [onRowSelectionChange]
    );

    return { handleRowSelectionModelChange };
};
