import { useMemo, useState } from "react";
import { useSuspenseQuery } from "@tanstack/react-query";
import { Column } from "@tanstack/react-table";
import { useParams } from "react-router-dom";
import { toast } from "sonner";
import { match, P } from "ts-pattern";

import { getQueryOptions as getNumLabelsOpts } from "@/app/dashboard/submissions/:submissionId/(stack-lists)/NumLabelsToReview";
import { createManyFeedbackMutation } from "@/app/dashboard/training/useCreateManyFeedback";
import { FeedbackCreateManyInput } from "@/gql/graphql";
import { CONFIDENCE_INTERVALS } from "@/lib/constants/confidenceIntervals";
import { createQueryOptions, useGraphqlMutation } from "@/lib/hooks/graphql";
import { queryClient } from "@/queryClient";

import { PredictionsTable } from "../../../(predictions)/PredictionsTable";
import {
  createPredictionColumnDefs,
  PredictionColumnDef,
  predictionRowsQuery,
} from "../../../(predictions)/PredictionTableColumns";
import { useTransformationPredictionIds } from "./useTransformationPredictionIds";

/**
 * Query options for the prediction rows that belong to one transformation.
 *
 * @param transformationId - Id of the transformation whose predictions to load.
 * @returns The options built by `createQueryOptions` for `predictionRowsQuery`;
 *   suitable for `useSuspenseQuery`, and their `queryKey` addresses the cache entry.
 */
export const getQueryOptions = (transformationId: number) => {
  const where = {
    transformationId: { equals: transformationId },
    // Disabled: additionally restrict to labels in the Low confidence interval.
    // labels: {
    //   some: {
    //     confidence: {
    //       gte: CONFIDENCE_INTERVALS.Low[0],
    //       lt: CONFIDENCE_INTERVALS.Low[1],
    //     },
    //   },
    // },
  };

  return createQueryOptions({
    query: predictionRowsQuery,
    variables: { where },
  });
};

/**
 * Column ids that the transformation predictions table hides by default.
 * Any column whose id is not in this set starts out visible.
 */
export const HIDDEN_COLS = new Set<string>([
  "Transformation Type",
  "Row Index",
  "Transformation Id",
  "Vendor Id",
  "Created At",
]);

/**
 * Table of every prediction row produced by the transformation identified by
 * the `:transformationId` route param.
 *
 * Feedback submitted from the table is sent with `createManyFeedbackMutation`
 * and applied optimistically to the cached rows:
 * - new feedback appears immediately, with temporary negative ids;
 * - the cache is restored from a snapshot if the mutation fails;
 * - on success, the rows query and the "labels to review" count query are
 *   invalidated so server data replaces the optimistic entries.
 */
export function TransformationPredictionsTable() {
  // NOTE(review): assumes the route always provides `:transformationId`;
  // `Number(undefined)` would be NaN.
  const { transformationId } = useParams();

  // Per-column filter-visibility flags; no column has an entry initially.
  const [showFilterState, setShowFilterState] = useState<
    Record<string, boolean>
  >({});

  const options = useMemo(
    () => getQueryOptions(Number(transformationId)),
    [transformationId],
  );
  const { data: predictionIds } = useTransformationPredictionIds(
    Number(transformationId),
  );
  const numRevOptions = getNumLabelsOpts({
    predictionIds,
    interval: CONFIDENCE_INTERVALS.Low,
  });

  const { data: predictionRowData } = useSuspenseQuery(options);

  const { queryKey } = options;
  const { predictions } = predictionRowData;

  const { mutate } = useGraphqlMutation({
    mutation: createManyFeedbackMutation,
    onMutate: async ({ data: mutationData }) => {
      // Stop in-flight refetches so they cannot overwrite the optimistic data.
      await queryClient.cancelQueries({ queryKey });

      // Snapshot of the cache, used by onError to roll back.
      const previousData =
        queryClient.getQueryData<typeof predictionRowData>(queryKey);

      const updateOptimistic = (feedback: FeedbackCreateManyInput[]) => {
        // Placeholder feedback record; the negative id marks it as unsaved.
        const toOptimistic = (f: FeedbackCreateManyInput, i: number) => ({
          ...f,
          id: (i + 1) * -1,
          upvote: !!f.upvote,
        });

        // Group the new entries by target label in one pass. Each entry is
        // prepended, so several entries for one label end up newest-first.
        const pendingByLabel = new Map<
          unknown,
          ReturnType<typeof toOptimistic>[]
        >();
        feedback.forEach((f, i) => {
          pendingByLabel.set(f.labelId, [
            toOptimistic(f, i),
            ...(pendingByLabel.get(f.labelId) ?? []),
          ]);
        });

        queryClient.setQueryData<typeof predictionRowData>(
          queryKey,
          (current) => {
            if (!current) return undefined;

            return {
              // Preserve any other fields of the cached response.
              ...current,
              predictions: current.predictions.map((p) => ({
                ...p,
                labels: p.labels.map((label) => {
                  const pending = pendingByLabel.get(label.id);
                  // Labels without new feedback keep their object identity.
                  return pending
                    ? { ...label, feedback: [...pending, ...label.feedback] }
                    : label;
                }),
              })),
            };
          },
        );
      };

      // The mutation accepts a single feedback input or an array of them.
      match(mutationData)
        .with(P.array(P.any), updateOptimistic)
        .with(P.any, (feedback) => updateOptimistic([feedback]))
        .exhaustive();

      return { previousData };
    },
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey });
      queryClient.invalidateQueries({ queryKey: numRevOptions.queryKey });
    },
    onError(_error, _variables, context) {
      toast.error("Error submitting labels");
      // `context` is undefined when onMutate failed before returning a
      // snapshot; reading `.previousData` unguarded would throw here.
      // @ts-ignore - https://github.com/TanStack/query/discussions/3434#discussioncomment-2425225
      const previousData = context?.previousData;
      if (previousData !== undefined) {
        queryClient.setQueryData(queryKey, previousData);
      }
    },
  });

  const columns = createPredictionColumnDefs({
    predictions,
    setShowFilterState,
  });

  // Every column is visible unless its id is listed in HIDDEN_COLS.
  const colsVisibleState = (columns as Column<PredictionColumnDef>[]).reduce(
    (acc, col) => {
      acc[col.id] = !HIDDEN_COLS.has(col.id);
      return acc;
    },
    {} as Record<string, boolean>,
  );

  return (
    <PredictionsTable
      key={transformationId}
      columns={columns}
      columnVisibility={colsVisibleState}
      predictions={predictions}
      className="p-0"
      showFilterState={showFilterState}
      onSubmitFeedback={(feedback) => {
        mutate({
          data: feedback,
        });
      }}
    />
  );
}
