import { capitalize, Divider, styled, Typography } from "@mui/material"
import { axisClasses, LineChart, type LineSeriesType } from "@mui/x-charts"
import { useMemo } from "react"
import { dayjsUTC } from "../../utils/helpers"
import { extraColors } from "../../utils/theme"
import type { TabularData } from "./askDB"
import { AskGradientBorder } from "./AskUI"

/**
 * Fixed-height, full-width wrapper around `AskGradientBorder` that hosts the
 * chart title, divider and line chart.
 */
const ChartContainer = styled(AskGradientBorder)({
  width: "100%",
  height: 400,
  padding: 12,
})
/**
 * Line chart of the rows in `tabularData`, drawing one series per location.
 *
 * Rows are bucketed by UTC calendar day so that the grouping agrees with the
 * UTC formatting used on the axis and tooltip, independent of the viewer's
 * locale or timezone. Each bucket becomes one dataset point holding a value
 * (or `null` when missing) for every location.
 *
 * @param tabularData - Query result to plot. When it is absent, has no rows,
 *   or none of its rows has a parseable `start_time`, nothing is rendered.
 * @returns The chart wrapped in `ChartContainer`, or `null` when there is
 *   nothing to plot.
 */
export const AskCharts = ({ tabularData }: { tabularData?: TabularData }) => {
  const formattedData = useMemo(() => {
    const rows = tabularData?.data
    if (!rows?.length) {
      return { dataset: [], series: [] }
    }

    // Pass 1: collect the unique location ids (in first-seen order) together
    // with the label taken from the first row of each location. Doing this in
    // one pass avoids a full scan of the rows per location later on.
    const labelByLocation = new Map<string, string>()
    for (const row of rows) {
      const locationId = row.location_osm_id.toString()
      if (!labelByLocation.has(locationId)) {
        labelByLocation.set(locationId, row.source_location ?? locationId)
      }
    }
    const locationIds = [...labelByLocation.keys()]

    // Pass 2: bucket rows by UTC day. Every bucket starts with `null` for all
    // locations so that series with gaps still line up on the shared x-axis.
    const msPerDay = 24 * 60 * 60 * 1000
    const pointsByUtcDay = new Map<number, Record<string, number | null>>()
    for (const row of rows) {
      const timestamp = new Date(row.start_time).getTime()
      // An unparseable start_time would put NaN on the time axis; skip it.
      if (Number.isNaN(timestamp)) {
        continue
      }

      const utcDay = Math.floor(timestamp / msPerDay)
      let point = pointsByUtcDay.get(utcDay)
      if (!point) {
        point = {
          // The first timestamp seen for the day positions the point.
          date: timestamp,
          ...Object.fromEntries(locationIds.map(id => [id, null])),
        }
        pointsByUtcDay.set(utcDay, point)
      }

      // Only numeric values are plotted; anything else stays `null`.
      if (typeof row.value === "number") {
        point[row.location_osm_id.toString()] = row.value
      }
    }

    // Chronological order for the time axis
    const dataset = [...pointsByUtcDay.values()].sort(
      (a, b) => (a.date ?? 0) - (b.date ?? 0)
    )

    // One line per location, keyed by the location id column of the dataset
    const series = locationIds.map(
      locationId =>
        ({
          type: "line" as const,
          label: labelByLocation.get(locationId) ?? locationId,
          showMark: false,
          dataKey: locationId,
          valueFormatter: (v: number | null) => v?.toLocaleString() ?? "N/A",
          connectNulls: true,
        }) satisfies LineSeriesType
    )

    return { dataset, series }
  }, [tabularData?.data])

  if (
    !tabularData ||
    !formattedData.dataset.length ||
    !formattedData.series.length
  )
    return null

  // Y-axis label: vaccination data is labelled as a rate, everything else as a count
  const chartName = `${capitalize(tabularData.type)} ${
    tabularData.type === "vaccination" ? "Rate (%)" : "Count"
  }`

  return (
    <ChartContainer>
      <Typography variant="body2Bold" textTransform="capitalize">
        {tabularData.diseaseCode} {tabularData.type} {tabularData.unit}
      </Typography>
      <Divider />
      <LineChart
        disableAxisListener
        dataset={formattedData.dataset}
        series={formattedData.series}
        sx={theme => ({
          paddingBottom: "20px",
          [`.${axisClasses.line}`]: {
            stroke: extraColors.disabled,
          },
          [`.${axisClasses.tickLabel}, .${axisClasses.tick}`]: {
            fontWeight: theme.typography.small1Bold.fontWeight,
            fill: extraColors.medium,
          },
          [`.${axisClasses.label}`]: {
            fontSize: theme.typography.small1.fontSize,
            fill: extraColors.medium,
          },
          // Hide the first tick label for y-axis
          [`.${axisClasses.directionY} g:first-of-type`]: {
            display: "none",
          },
        })}
        xAxis={[
          {
            label: "Month/Year",
            scaleType: "time",
            dataKey: "date",
            // Ticks show month/year only; tooltip and legend show the full date
            valueFormatter: (d: number, context) =>
              dayjsUTC(d).format(
                context.location === "tick" ? "M/YY" : "M/DD/YY"
              ),
            disableTicks: true,
          },
        ]}
        yAxis={[
          {
            label: chartName,
            disableLine: true,
          },
        ]}
        slotProps={{
          legend: {
            hidden: true,
          },
        }}
        tooltip={{
          // With many series an axis tooltip lists every line; switch to
          // per-item tooltips once there are five or more
          trigger: formattedData.series.length < 5 ? "axis" : "item",
        }}
        margin={{ top: 20, right: 20, bottom: 40, left: 40 }}
        experimentalMarkRendering
      />
    </ChartContainer>
  )
}
