import { capitalize, Divider, styled, Typography } from "@mui/material"
import {
  axisClasses,
  ChartContainerPro,
  ChartsAxisHighlight,
  ChartsClipPath,
  ChartsTooltipContainer,
  ChartsXAxis,
  ChartsYAxis,
  LineHighlightPlot,
  LinePlot,
  type LineSeriesType,
} from "@mui/x-charts-pro"
import {} from "@mui/x-charts-pro/typeOverloads"
import { useId, useMemo, useState } from "react"
import { dayjsUTC } from "../../utils/helpers"
import { extraColors } from "../../utils/theme"
import { AskChartTooltip, type AskHighlightedItem } from "./AskChartTooltip"
import type { TabularData } from "./askDB"
import { AskGradientBorder } from "./AskUI"

/**
 * Frame around the chart: an `AskGradientBorder` stretched to the full width
 * of its parent with a fixed 400px height and 12px of inner padding.
 */
const ChartContainerLayout = styled(AskGradientBorder)({
  width: "100%",
  height: 400,
  padding: 12,
})
/** Length of one UTC calendar day in milliseconds; used to bucket timestamps. */
const MS_PER_DAY = 24 * 60 * 60 * 1000

/**
 * Pivots long-format rows (one row per location per timestamp) into the
 * wide-format `dataset` plus per-location line `series` consumed by the chart.
 *
 * - Rows are bucketed by UTC calendar day, so the result does not depend on
 *   the viewer's locale or timezone.
 * - Rows whose `start_time` cannot be parsed are skipped.
 * - Every dataset row carries a `date` (epoch ms of the first row seen for
 *   that day) and one key per location id, `null` where no numeric value
 *   exists.
 * - Series with fewer than two non-null points are dropped, since a line
 *   cannot be drawn through a single point.
 *
 * @param rows Raw tabular rows; an empty array yields empty outputs.
 * @returns `dataset` sorted ascending by `date`, and the `series` config.
 */
const buildChartData = (rows: TabularData["data"]) => {
  type OsmId = string
  type ChartRow = Record<string, number | null>

  // Single pass: collect location ids in first-seen order together with the
  // label taken from the first row of each location (falling back to the id).
  const labelByOsmId = new Map<OsmId, string>()
  for (const row of rows) {
    const osmId = row.location_osm_id.toString()
    if (!labelByOsmId.has(osmId)) {
      labelByOsmId.set(osmId, row.source_location ?? osmId)
    }
  }
  const osmIds = [...labelByOsmId.keys()]

  // Template with a null for every location, so missing values stay explicit.
  const emptyValues: ChartRow = Object.fromEntries(
    osmIds.map(id => [id, null])
  )

  const rowByDay = new Map<number, ChartRow>()
  for (const row of rows) {
    const timestamp = new Date(row.start_time).getTime()
    // An unparseable start_time would produce a NaN x value and break sorting.
    if (Number.isNaN(timestamp)) continue

    // Whole days since the epoch in UTC: a timezone-independent bucket key.
    const dayKey = Math.floor(timestamp / MS_PER_DAY)
    let chartRow = rowByDay.get(dayKey)
    if (!chartRow) {
      chartRow = { ...emptyValues, date: timestamp }
      rowByDay.set(dayKey, chartRow)
    }

    // Only numeric values are plotted; anything else leaves the null in place.
    if (typeof row.value === "number") {
      chartRow[row.location_osm_id.toString()] = row.value
    }
  }

  // Chronological order for the time axis.
  const dataset = [...rowByDay.values()].sort(
    (a, b) => (a.date ?? 0) - (b.date ?? 0)
  )

  const series = osmIds
    .map(
      osmId =>
        ({
          type: "line" as const,
          label: labelByOsmId.get(osmId) ?? osmId,
          showMark: false,
          dataKey: osmId,
          valueFormatter: (v: number | null) => v?.toLocaleString() ?? "N/A",
          connectNulls: true,
          id: osmId,
          data: dataset.map(d => d[osmId] ?? null),
          highlightScope: {
            fade: "global",
            highlight: "series",
          },
        }) satisfies LineSeriesType
    )
    // Keep only series with at least two non-null data points.
    .filter(s => s.data.filter(d => d !== null).length > 1)

  return { dataset, series }
}

/**
 * Line chart of tabular data, one line per location, over a zoomable time
 * axis.
 *
 * Renders nothing when `tabularData` is missing or when no plottable dataset
 * rows or series remain after shaping.
 *
 * @param tabularData Rows to plot plus the metadata (`type`, `diseaseCode`,
 *   `unit`) used for the heading and the y-axis label.
 */
export const AskCharts = ({ tabularData }: { tabularData?: TabularData }) => {
  const chartId = useId()
  // Unique per component instance so several charts can coexist on one page.
  const clipPathId = `${chartId}-clip-path`
  const [highlightedItem, setHighlightedItem] =
    useState<AskHighlightedItem | null>(null)

  // Recompute the pivoted data only when the underlying rows change.
  const formattedData = useMemo(
    () => buildChartData(tabularData?.data ?? []),
    [tabularData?.data]
  )

  if (
    !tabularData ||
    !formattedData.dataset.length ||
    !formattedData.series.length
  )
    return null

  // Vaccination data is labelled as a percentage rate, everything else as a count.
  const chartName =
    tabularData.type === "vaccination"
      ? `${capitalize(tabularData.type)} Rate (%)`
      : `${capitalize(tabularData.type)} Count`

  return (
    <ChartContainerLayout>
      <Typography variant="body2Bold" textTransform="capitalize">
        {tabularData.diseaseCode} {tabularData.type} {tabularData.unit}
      </Typography>
      <Divider />
      <ChartContainerPro
        dataset={formattedData.dataset}
        series={formattedData.series}
        xAxis={[
          {
            // NOTE(review): the label says "Month/Year" while ticks are
            // formatted as M/DD/YY — confirm which one is intended.
            label: "Month/Year",
            scaleType: "time",
            dataKey: "date",
            valueFormatter: (d: number) => {
              return dayjsUTC(d).format("M/DD/YY")
            },
            disableTicks: true,
            zoom: true,
          },
        ]}
        yAxis={[
          {
            label: chartName,
            disableLine: true,
          },
        ]}
        sx={theme => ({
          [`.${axisClasses.line}`]: {
            stroke: extraColors.disabled,
          },
          [`.${axisClasses.tickLabel}, .${axisClasses.tick}`]: {
            fontWeight: theme.typography.small1Bold.fontWeight,
            fill: extraColors.medium,
          },
          [`.${axisClasses.label}`]: {
            fontSize: theme.typography.small1.fontSize,
            fill: extraColors.medium,
          },
          // Hide the first tick label for y-axis
          [`.${axisClasses.directionY} g:first-of-type`]: {
            display: "none",
          },
          // Nudge the y-axis label left so it clears the tick labels.
          [`.${axisClasses.left} .${axisClasses.label}`]: {
            transform: "translateX(-20px)",
          },
        })}
        margin={{ top: 20, right: 20, bottom: 60, left: 60 }}
        highlightedItem={highlightedItem}
      >
        {/* Clip the lines to the drawing area so zoomed data does not overflow. */}
        <g clipPath={`url(#${clipPathId})`}>
          <LinePlot />
        </g>
        <LineHighlightPlot />
        <ChartsXAxis />
        <ChartsYAxis />
        <ChartsTooltipContainer trigger="axis">
          <AskChartTooltip
            setHighlightedItem={setHighlightedItem}
            highLightedItem={highlightedItem}
          />
        </ChartsTooltipContainer>
        <ChartsAxisHighlight x="line" />
        <ChartsClipPath id={clipPathId} />
      </ChartContainerPro>
    </ChartContainerLayout>
  )
}
