import type { Select } from "@mui/joy";
import {
  Box,
  Card,
  Chip,
  Link,
  Modal,
  ModalClose,
  ModalDialog,
  ModalOverflow,
  Stack,
  Switch,
  Typography,
} from "@mui/joy";
import type { ComponentProps } from "react";
import { useState } from "react";
import { toast } from "react-toastify";
import { useOrganization } from "../../lib/api/organization";
import { usePrimaryColor } from "../../lib/hooks/useTheme";
import { useTranslation } from "../../lib/i18n";
import { useUnsafeModelWarning } from "./useUnsafeModelWarning";
import type { LlmName } from "../../../../backend/src/ai/llmMeta";
import { LLM_META, type LlmMetaData } from "../../../../backend/src/ai/llmMeta";
import { ModelIcon } from "../../lib/ModelIcon";
import { trpc } from "../../lib/api/trpc/trpc";

/** Props for choosing a single model (consumed by `ModelSelectorModal`). */
type ModelSelectorProps = {
  /**
   * The currently chosen model, or `null`. `ModelSelectorModal` highlights the
   * organization's default model when this is `null` or not among the listed
   * options.
   */
  selectedModel: LlmName | null;
  /** Called with the model the user picked. */
  setSelectedModel: (model: LlmName) => void;
  /**
   * When true, enabled models are listed even if their metadata does not set
   * `allowChat`. Treated as `false` when omitted.
   */
  allModels?: boolean;
};

/** Props for toggling several models on or off (consumed by `EnabledModelsSelector`). */
type MultiModelSelectorProps = {
  /** Models to render, one card each, in this order. */
  availableModels: LlmName[];
  /** Keys of the models that are currently switched on. */
  selectedModels: string[];
  /**
   * Called with a model key and the enabled state it should take. A rejected
   * promise is reported to the user as an error toast; the resolved boolean is
   * not read by `EnabledModelsSelector`.
   */
  updateModel: (modelKey: string, enabled: boolean) => Promise<boolean>;
  /** Forwarded to every `ModelCard`; invoked when the user marks a model as default. */
  setDefaultModel?: (model: LlmName) => void;
};

/**
 * Hook returning the metadata of every model that may be offered to the user.
 *
 * A model is included when it appears in the result of the
 * `modelConfig.getEnabled` query and either `allModels` is true or its
 * metadata sets `allowChat`. While the query has no data the result is empty.
 *
 * @param allModels - Skip the `allowChat` requirement when true.
 * @returns The matching `LLM_META` entries, each extended with its `key`.
 */
function useAvailableModelsForChat(allModels: boolean = false) {
  const enabledQuery = trpc.modelConfig.getEnabled.useQuery();
  const enabledModels = enabledQuery.data;

  // Iterate LLM_META rather than the enabled list so the output keeps
  // LLM_META's ordering.
  const offered = Object.entries(LLM_META)
    .filter(([modelName, meta]) => {
      if (!allModels && !meta.allowChat) {
        return false;
      }
      return enabledModels?.includes(modelName as LlmName);
    })
    .map(([modelName, meta]) => ({ ...meta, key: modelName }));

  return offered as (LlmMetaData & { key: LlmName })[];
}

/**
 * Modal dialog in which the user picks a single model.
 *
 * Shows one `ModelCard` per model returned by `useAvailableModelsForChat`.
 * Exactly one card is highlighted: the one for `selectedModel` if that model
 * is among the listed options, otherwise the one for the organization's
 * default model (if listed).
 *
 * The props type also admits Joy `Select` props for backward compatibility,
 * but only the five props destructured below are read.
 *
 * @param selectedModel - Currently chosen model, or `null`.
 * @param setSelectedModel - Receives the model the user picked.
 * @param allModels - Forwarded to `useAvailableModelsForChat`.
 * @param open - Whether the modal is shown.
 * @param onRequestClose - Called when the modal should close.
 */
export function ModelSelectorModal({
  selectedModel,
  setSelectedModel,
  allModels,
  open,
  onRequestClose,
}: ModelSelectorProps &
  Partial<ComponentProps<typeof Select>> & {
    open: boolean;
    onRequestClose: () => void;
  }) {
  const organization = useOrganization();
  const defaultModel = organization?.defaultModel ?? "";
  const options = useAvailableModelsForChat(allModels);

  const nonEuWarningSkippable = organization?.nonEuWarningSkippable ?? false;

  const unsafeModelWarning = useUnsafeModelWarning(
    setSelectedModel,
    () => {
      onRequestClose();
    },
    nonEuWarningSkippable
  );

  const { t } = useTranslation();

  // Decide once which card is highlighted instead of re-scanning `options`
  // for every card: the explicit selection wins when it is listed, otherwise
  // fall back to the organization default.
  const selectionIsListed = options.some(
    (option) => option.key === selectedModel
  );
  const highlightedModel = selectionIsListed ? selectedModel : defaultModel;

  // Must match the `id` of the title element so the dialog gets an accessible
  // name via `aria-labelledby`.
  const titleId = "model-dialog-overflow";

  return (
    <>
      {unsafeModelWarning.renderModal()}
      <Modal open={open} onClose={onRequestClose}>
        <ModalOverflow>
          <ModalDialog
            aria-labelledby={titleId}
            sx={{ width: "70%", maxWidth: "lg", py: 5 }}
          >
            <ModalClose />
            <Typography id={titleId} level="h1" width="100%" textAlign="center">
              {t("chatSettings")}
            </Typography>
            <Typography level="h4" sx={{ mb: 3 }} textAlign="center">
              {t("selectModel")}
            </Typography>
            <div
              className="grid gap-8"
              style={{
                gridTemplateColumns: "repeat(auto-fill, minmax(350px, 1fr))",
              }}
            >
              {options.map((option) => (
                <ModelCard
                  key={option.key}
                  defaultModel={defaultModel}
                  selected={option.key === highlightedModel}
                  setSelectedModel={setSelectedModel}
                  unsafeModelWarning={unsafeModelWarning}
                  onCloseModal={onRequestClose}
                  keyName={option.key}
                />
              ))}
            </div>
          </ModalDialog>
        </ModalOverflow>
      </Modal>
    </>
  );
}

/**
 * Card presenting one model: icon, name, provider, hosting flag, quality and
 * speed ratings, description, capability chips and a "more info" link.
 *
 * The card works in two modes, chosen by whether `setDefaultModel` is passed
 * (it is only provided on the General Settings page):
 *
 * - Picker mode (no `setDefaultModel`): the whole card is clickable and the
 *   selected card is outlined in the primary color.
 * - Settings mode (`setDefaultModel` given): the card itself is inert; a
 *   switch toggles the model, and a chip or hover link shows or sets the
 *   default model.
 *
 * In both modes a model whose `hostingLocation` is not `"EU"` is handed to
 * `unsafeModelWarning.onChooseUnsafeModel` instead of being selected directly
 * (except when switching off an already selected model).
 *
 * @param selected - Whether this model is currently selected or enabled.
 * @param defaultModel - Key of the default model; the matching card shows the default chip.
 * @param setDefaultModel - Enables settings mode; called when the user makes this model the default.
 * @param setSelectedModel - Called with `keyName` when the model is chosen or toggled.
 * @param unsafeModelWarning - Result of `useUnsafeModelWarning`, used for non-EU models.
 * @param onCloseModal - Optional callback run right after a direct selection.
 * @param keyName - Key of the model in `LLM_META`.
 */
export function ModelCard({
  selected = true,
  defaultModel,
  setDefaultModel,
  setSelectedModel,
  unsafeModelWarning,
  onCloseModal,
  keyName,
}: {
  selected: boolean;
  defaultModel: string;
  setDefaultModel?: (model: LlmName) => void;
  setSelectedModel: (model: LlmName) => void;
  unsafeModelWarning: ReturnType<typeof useUnsafeModelWarning>;
  onCloseModal?: () => void;
  keyName: LlmName;
}) {
  const { t } = useTranslation();
  const {
    name,
    provider,
    quality,
    speed,
    infoUrl,
    capabilities,
    hostingLocation,
  } = LLM_META[keyName] ?? {};
  const primaryColor = usePrimaryColor();
  const [isHovered, setIsHovered] = useState(false);

  const isDefault = keyName === defaultModel;
  const hostedInEU = hostingLocation === "EU";

  // In settings mode, switched-off models are greyed out.
  const backgroundColor =
    setDefaultModel && !selected ? "rgb(235,235,235)" : "white";

  // Renders a five-segment bar with the first `rating` segments filled.
  // Each segment carries a key because React requires one for list children.
  const ratingBar = (rating: number) =>
    Array.from({ length: 5 }, (_, i) => (
      <div
        key={i}
        style={{
          border: "1px solid gray",
          // "none" is not a valid color value; "transparent" is the explicit
          // equivalent for an unfilled segment.
          backgroundColor: i < rating ? primaryColor : "transparent",
          width: "25px",
          height: "10px",
          display: "inline-block",
          borderRadius: "3px",
        }}
      />
    ));

  // Selects this model directly and lets the caller close its modal, if any.
  const chooseModel = () => {
    setSelectedModel(keyName);
    onCloseModal?.();
  };

  return (
    <Card
      onMouseEnter={() => setIsHovered(true)}
      onMouseLeave={() => setIsHovered(false)}
      sx={{
        background: backgroundColor,
        cursor: setDefaultModel ? "default" : "pointer",
        ml: 0,
        height: "350px",
        "&:hover": {
          background: setDefaultModel ? backgroundColor : "rgb(245,245,245)",
        },
        outline:
          !setDefaultModel && selected ? "3px solid " + primaryColor : "none",
      }}
      onClick={() => {
        // Settings mode: only the switch and the default link react to clicks.
        if (setDefaultModel) {
          return;
        }
        if (hostedInEU) {
          chooseModel();
        } else {
          unsafeModelWarning.onChooseUnsafeModel(keyName);
        }
      }}
    >
      {setDefaultModel && (
        <Stack position="absolute" right={10} gap={0.5} alignItems="flex-end">
          <Switch
            checked={selected}
            size="lg"
            sx={{ mr: 0.5, alignSelf: "end" }}
            onClick={() => {
              // An already selected model, or an EU-hosted one, is passed
              // straight to setSelectedModel; only a non-selected non-EU
              // model goes through the warning.
              if (selected || hostedInEU) {
                chooseModel();
              } else {
                unsafeModelWarning.onChooseUnsafeModel(keyName);
              }
            }}
          />
          {isDefault ? (
            <Chip color="primary">{t("default")}</Chip>
          ) : (
            isHovered && (
              <Link
                color="neutral"
                level="body-sm"
                onClick={(e) => {
                  e.stopPropagation();
                  setDefaultModel?.(keyName);
                  // NOTE(review): this selects a not-yet-selected model
                  // without the non-EU warning — confirm that is intended.
                  if (!selected) {
                    setSelectedModel(keyName);
                  }
                }}
              >
                {t("default")}
              </Link>
            )
          )}
        </Stack>
      )}
      <Box>
        <Stack direction="row" gap={2}>
          <ModelIcon
            modelName={keyName}
            style={{ height: "60px", width: "auto", marginBottom: "10px" }}
          />
          <Stack>
            <Typography level="h4">{name}</Typography>
            <Typography level="body-md">
              {provider} ({hostedInEU ? "🇪🇺" : "🇺🇸"})
            </Typography>
          </Stack>
        </Stack>

        <Stack direction="row" gap={2}>
          <Stack>
            <Typography level="body-md" sx={{ width: "100%" }}>
              {t("textQuality")}:
            </Typography>
            <Typography level="body-md" sx={{ width: "100%" }}>
              {t("speed")}:
            </Typography>
          </Stack>
          <Stack justifyContent="space-evenly">
            <Stack direction="row" gap={1}>
              {ratingBar(quality)}
            </Stack>
            <Stack direction="row" gap={1}>
              {ratingBar(speed)}
            </Stack>
          </Stack>
        </Stack>
        <Typography level="body-sm" my={2} width="100%">
          {t("modelMeta." + keyName)}
        </Typography>
        <Stack direction="row" gap={1} flexWrap="wrap">
          {capabilities?.map((capability) => (
            <Chip key={capability}>{capability}</Chip>
          ))}
        </Stack>
        <Link
          color="neutral"
          onClick={(e) => {
            // Keep the click from also triggering the card's onClick.
            e.stopPropagation();
          }}
          href={infoUrl}
          target="_blank"
          rel="noopener noreferrer"
          level="body-xs"
          sx={{
            position: "absolute",
            bottom: 5,
            right: 5,
          }}
        >
          {t("moreInfo")}
        </Link>
      </Box>
    </Card>
  );
}

/**
 * Grid of `ModelCard`s in settings mode, used to switch models on and off and
 * to choose the default model.
 *
 * Toggling a card calls `updateModel` with the model key and the opposite of
 * its current state; if that promise rejects, a translated error toast is
 * shown. Non-EU models are routed through `useUnsafeModelWarning` by the
 * cards before being enabled.
 *
 * @param availableModels - Models to render, one card each.
 * @param selectedModels - Keys of the models that are currently enabled.
 * @param updateModel - Persists the new enabled state of a model.
 * @param setDefaultModel - Forwarded to each card to set the default model.
 */
export function EnabledModelsSelector({
  availableModels,
  selectedModels,
  updateModel,
  setDefaultModel,
}: MultiModelSelectorProps) {
  const { t } = useTranslation();
  const organization = useOrganization();
  const defaultModel = organization?.defaultModel ?? "";
  const nonEuWarningSkippable = organization?.nonEuWarningSkippable ?? false;

  // Flips the enabled state of `key`.
  const selectModel = (key: LlmName): void => {
    const shouldEnable = !selectedModels.includes(key);
    updateModel(key, shouldEnable).catch(() => {
      // Translate the key; passing it raw would show "errorDisplay.title"
      // to the user verbatim.
      toast.error(t("errorDisplay.title"));
    });
  };

  const unsafeModelWarning = useUnsafeModelWarning(
    selectModel,
    () => {},
    nonEuWarningSkippable
  );

  return (
    <div
      className="grid gap-8"
      style={{ gridTemplateColumns: "repeat(auto-fill, minmax(350px, 1fr))" }}
    >
      {unsafeModelWarning.renderModal()}
      {availableModels.map((key) => (
        <ModelCard
          key={key}
          keyName={key}
          selected={selectedModels.includes(key)}
          defaultModel={defaultModel}
          setSelectedModel={selectModel}
          setDefaultModel={setDefaultModel}
          unsafeModelWarning={unsafeModelWarning}
        />
      ))}
    </div>
  );
}
