import React, { Fragment, useMemo, useState } from 'react';
import { isArray, isNumber, last, max, uniq } from 'lodash';

import { ScaleLinear, ScaleOrdinal, ScaleTime } from 'd3-scale';

import { ParentSize } from '@visx/responsive';
import { AreaClosed, AreaStack, Bar, Line, LinePath } from '@visx/shape';
import { Grid, GridColumns } from '@visx/grid';
import { AxisBottom, AxisLeft, AxisRight } from '@visx/axis';
import { scaleLinear, scaleTime, scaleOrdinal } from '@visx/scale';
import { LinearGradient } from '@visx/gradient';
import { Group } from '@visx/group';
import { localPoint } from '@visx/event';
import { curveMonotoneX } from '@visx/curve';
import { getStringWidth } from '@visx/text';

import { formatNumberToCompact } from '@/lib/utils/number';
import { findClosestDate, getFormattedDateForPeriod } from '@/lib/date';
import { CheckboxWithLabel } from '@/components/form/checkbox';
import { Grouping, TimeAggregationPeriod } from '@/explore/types';

import { ColoredLegend } from '../legend';
import { ChartTooltip, DefaultChartHeight } from '../common';
import { getHighlightedTickValues, getTickFormatFn, getTickValues } from './utils';
import { getMaxValue, getMinValue, getSeries, getValueKeys } from '../grouped-chart/utils';
import { GroupedChartData, TimeSeriesData } from '../grouped-chart/types';

import { GroupedChart } from '../grouped-chart';

import commonStyles, {
  gridLineColor,
  gridHighlightedLineColor,
  hoverLineColor,
  barColor1,
  barColor1Accent,
  barColor2,
  barColor2Accent,
  barColor3,
  barColor3Accent,
  barColor4,
  barColor4Accent,
  barColor5,
  barColor5Accent,
  barColor6,
  barColor6Accent,
} from '../charts.module.scss';

/** Tick count requested from the value (Y) scales via `scale.ticks(NumYTicks)`. */
const NumYTicks = 7;
/**
 * Number of grouping levels subtracted from the group count before multiplying by
 * `WidthPerGroup` when reserving horizontal space for group columns.
 * NOTE(review): presumably the innermost levels are drawn without their own column — confirm
 * against the GroupedChart implementation.
 */
const SeriesDepth = 2;
/** Pixel height of every chart SVG; also the upper bound of the value scales' range. */
const ChartHeight = DefaultChartHeight;
/**
 * Horizontal space reserved per grouping column, read from the shared stylesheet export.
 * The radix is passed explicitly so the string is always parsed as a decimal number.
 */
const WidthPerGroup = parseInt(commonStyles.widthPerGroup, 10);
/** Fraction of the value extent added as padding to the Y domain. */
const YValuePaddingCoef = 0.1;
/** Minimum left margin and fixed right margin (px) of the chart footer. */
const HorizontalPadding = 20;
/** Top margin (px) of the chart footer when the grid and bottom axis are hidden. */
const VerticalPadding = 10;

/**
 * Rows handed to the stacked area renderer: one object per data point holding its
 * `date` plus one entry per series key (see `generateStackedChartData`).
 */
type StackedChartData = {
  // Series values keyed by series key; `Date` is part of the union only so the
  // explicit `date` property below is compatible with the index signature.
  [key: string]: number | Date;
  date: Date;
}[];

/**
 * Splits a time series into two row sets, one per value axis, in the shape expected
 * by the stacked area renderer.
 *
 * Every returned row carries the item's date plus one property for each series that
 * is assigned to the respective axis. A series key that is absent from an item's
 * values is still emitted, with the value `undefined`.
 *
 * @param data - Series definitions (`key`, `axis`) and the dated items holding their values.
 * @returns Rows for the `left` and the `right` axis, in the same order as `data.items`.
 */
const generateStackedChartData = (
  data: TimeSeriesData,
): {
  left: StackedChartData;
  right: StackedChartData;
} => {
  type Row = StackedChartData[number];

  // Flatten each item into a single record first, exactly like the values will be looked up below.
  const rows: StackedChartData = data.items.map((item) => ({
    date: item.dateValue,
    ...item.values,
  }));

  // Produce the rows for one axis by copying only the keys of the series plotted on it.
  const rowsForAxis = (side: 'left' | 'right'): StackedChartData => {
    const axisKeys = data.series.filter((series) => series.axis === side).map((series) => series.key);

    return rows.map((row) => {
      const projected: Row = { date: row.date };
      for (const key of axisKeys) {
        projected[key] = row[key];
      }
      return projected;
    });
  };

  return {
    left: rowsForAxis('left'),
    right: rowsForAxis('right'),
  };
};

/** Colour pair assigned to a series by the ordinal colour scale. */
interface ChartColor {
  /** Used for the area gradient (and its element id) and for the tooltip / hover markers. */
  fill: string;
  /** Used for the line stroke on top of the area and for the legend entry. */
  highlight: string;
}

/** Props of a single (ungrouped) time-series area chart. */
interface TimeSeriesChartProps {
  /** Width of the chart SVG, grid and hover overlay in pixels. */
  width: number;
  /**
   * Height forwarded to the stacked area component and used as fallback Y coordinate.
   * Note that the SVG itself is sized with the module-level `ChartHeight` constant.
   */
  height: number;
  /** Series definitions and dated values to plot. */
  data: TimeSeriesData;
  /** Maps dates to X pixel positions. */
  dateScale: ScaleTime<number, number>;
  /** Maps values to Y pixel positions, one scale per value axis. */
  valueScales: {
    left: ScaleLinear<number, number>;
    right: ScaleLinear<number, number>;
  };
  /** Maps a series key to its colours. */
  colorScale: ScaleOrdinal<string, ChartColor>;
  /** Aggregation period; selects grid/tick values and the tooltip date format. */
  aggPeriod: TimeAggregationPeriod;
  /** Date currently hovered, or `null` when nothing is highlighted. */
  highlightedDate: Date | null;
  /** Invoked for mouse-move, touch-start and touch-move events on the chart overlay. */
  onMouseMove?: (
    event: React.TouchEvent<SVGRectElement> | React.MouseEvent<SVGRectElement>,
  ) => void;
  /** Invoked when the pointer leaves the chart overlay. */
  onMouseLeave?: () => void;
  /** Render the series of each axis stacked on top of each other instead of overlapping. */
  stacked: boolean;
  /** Hide the background grid and both value axes. */
  hideGrid: boolean;
  /** Timezone passed to the tooltip's date formatter. */
  timezone: string;
}

/** Value axes a series can be attached to, in the order they are rendered. */
const ValueAxes = ['left', 'right'] as const;

/**
 * Computes the padded Y domain for one axis in stacked mode.
 *
 * The upper bound is the largest per-row sum of all numeric values; the lower bound is the
 * smallest individual value when it is negative, otherwise 0. Both ends are padded by
 * `YValuePaddingCoef` of the combined extent (the lower end only when it is negative).
 * An empty row list yields `[0, 0]` instead of a NaN domain.
 */
const getStackedDomain = (rows: StackedChartData): [number, number] => {
  const rowTotals = rows.map((row) =>
    Object.values(row)
      .filter(isNumber)
      .reduce((sum, value) => sum + value, 0),
  );
  // Math.max() of nothing is -Infinity, which would poison the domain with NaN.
  const maxStackedValue = rowTotals.length > 0 ? Math.max(...rowTotals) : 0;
  const minStackedValue = Math.min(
    0,
    ...rows.map((row) => Math.min(...Object.values(row).filter(isNumber))),
  );

  const valuePadding = (Math.abs(minStackedValue) + Math.abs(maxStackedValue)) * YValuePaddingCoef;
  const lowerBound = minStackedValue < 0 ? minStackedValue - valuePadding : 0;

  return [lowerBound, maxStackedValue + valuePadding];
};

/**
 * Renders one time-series chart: gradient-filled areas with a line on top for every series,
 * an optional grid with left/right value axes, and a hover indicator with tooltip for the
 * highlighted date.
 *
 * In stacked mode the series of each axis are stacked and the Y domains are recomputed from
 * the stacked totals. The recomputed domains are applied to copies of the incoming scales, so
 * the `valueScales` prop (which the caller may share between several charts) is never mutated.
 */
const TimeSeriesChart = (props: TimeSeriesChartProps) => {
  const {
    width,
    height,
    data,
    dateScale,
    valueScales,
    colorScale,
    aggPeriod,
    highlightedDate,
    onMouseMove,
    onMouseLeave,
    stacked,
    hideGrid,
    timezone,
  } = props;

  const stackedData = useMemo(() => generateStackedChartData(data), [data]);

  // Scales actually used for drawing. Derived before anything reads tick values so that
  // axis label widths, grid rows and areas all agree on the same domain.
  const scales = useMemo<TimeSeriesChartProps['valueScales']>(() => {
    if (!stacked) {
      return valueScales;
    }
    return {
      left: valueScales.left.copy().domain(getStackedDomain(stackedData.left)),
      right: valueScales.right.copy().domain(getStackedDomain(stackedData.right)),
    };
  }, [stacked, valueScales, stackedData]);

  const tooltipData: TooltipData | null = useMemo(() => {
    if (highlightedDate === null) {
      return null;
    }

    const closestDataPoint = data.items.find(
      ({ dateValue }) => dateValue.getTime() === highlightedDate.getTime(),
    );

    if (closestDataPoint === undefined) {
      return null;
    }

    const points = data.series.map(({ key, label, axis }) => ({
      label,
      // Series without a value for this date are shown as "N/A" in the tooltip.
      value: closestDataPoint.values[key] ?? 'N/A',
      date: closestDataPoint.dateValue,
      color: colorScale(key).fill,
      axis,
    }));

    return {
      date: highlightedDate,
      timezone,
      aggPeriod,
      points,
    };
  }, [highlightedDate, data, timezone, aggPeriod, colorScale]);

  // Width of the widest tick label per axis, used to inset the axes from the chart edges.
  const yAxisWidths = useMemo(() => {
    const widestTickLabel = (scale: ScaleLinear<number, number>) =>
      max(
        scale
          .ticks(NumYTicks)
          .map((tick) => formatNumberToCompact({ num: tick.valueOf() }))
          .map((label) => getStringWidth(label)),
      ) ?? 0;

    return {
      left: widestTickLabel(scales.left),
      right: widestTickLabel(scales.right),
    };
  }, [scales]);

  const keys = data.series.map(({ key }) => key);

  // An axis with an empty domain has nothing to show and is skipped entirely.
  const domainSizes = {
    left: scales.left.domain()[1] - scales.left.domain()[0],
    right: scales.right.domain()[1] - scales.right.domain()[0],
  };

  // Default to right as main axis in case no series on left axis
  const primaryValueScale = domainSizes.left > 0 ? scales.left : scales.right;

  // Force right axis ticks to align with primary axis ticks
  const rightAxisTicks = primaryValueScale
    .ticks(NumYTicks)
    .map((tick) => scales.right.invert(primaryValueScale(tick)));

  const renderArea = () => {
    if (stacked) {
      return ValueAxes.map((axis) => {
        if (domainSizes[axis] <= 0) {
          return null;
        }
        // Only stack the series that belong to this axis: the per-axis rows contain no values
        // for the other axis' keys, which would otherwise produce NaN path coordinates.
        const axisKeys = data.series
          .filter((series) => series.axis === axis)
          .map(({ key }) => key);
        return (
          <AreaStack
            key={axis}
            keys={axisKeys}
            width={width}
            height={height}
            data={stackedData[axis]}
            order="reverse"
            x={(d) => dateScale(d.data.date)}
            y0={(d) => scales[axis](d[0])}
            y1={(d) => scales[axis](d[1]) ?? height}
            curve={curveMonotoneX}>
            {({ stacks, path }) =>
              stacks.map((stack) => {
                const color = colorScale(stack.key);
                return (
                  <Fragment key={`stack-${stack.key}`}>
                    <path
                      d={path(stack) ?? ''}
                      fill={`url(#gradient-${color.fill})`}
                      fillOpacity={1}
                    />
                    <LinePath
                      data={stack.map((item) => ({
                        date: item.data.date,
                        // Upper edge of this series' band.
                        value: item.at(1) ?? 0,
                      }))}
                      x={({ date }) => dateScale(date) ?? 0}
                      y={({ value }) => scales[axis](value) ?? 0}
                      stroke={color.highlight}
                      strokeOpacity={1}
                      curve={curveMonotoneX}
                    />
                  </Fragment>
                );
              })
            }
          </AreaStack>
        );
      });
    }

    return data.series.map(({ key, axis }) => {
      const color = colorScale(key);
      const seriesData = data.items.map((item) => ({
        date: item.dateValue,
        value: item.values[key],
      }));
      return (
        <Fragment key={key}>
          <AreaClosed
            data={seriesData}
            x={({ date }) => dateScale(date) ?? 0}
            y0={({ value }) => scales[axis](value) ?? 0}
            y1={() => scales[axis](0)}
            yScale={scales[axis]}
            fill={`url(#gradient-${color.fill})`}
            fillOpacity={1}
            curve={curveMonotoneX}
          />
          <LinePath
            data={seriesData}
            x={({ date }) => dateScale(date) ?? 0}
            y={({ value }) => scales[axis](value) ?? 0}
            stroke={color.highlight}
            strokeOpacity={1}
            curve={curveMonotoneX}
          />
        </Fragment>
      );
    });
  };

  return (
    <div className={commonStyles.graph}>
      <svg width={width} height={ChartHeight}>
        {hideGrid ? null : (
          <>
            <Grid
              width={width}
              height={ChartHeight}
              xScale={dateScale}
              yScale={primaryValueScale}
              rowTickValues={primaryValueScale.ticks(NumYTicks)}
              stroke={gridLineColor}
              columnTickValues={getTickValues(aggPeriod, dateScale)}
            />
            <GridColumns
              width={width}
              height={ChartHeight}
              scale={dateScale}
              stroke={gridHighlightedLineColor}
              tickValues={getHighlightedTickValues(aggPeriod, dateScale)}
            />
          </>
        )}
        {keys.map((key) => (
          <LinearGradient
            key={key}
            id={`gradient-${colorScale(key).fill}`}
            from={colorScale(key).fill}
            to={colorScale(key).fill}
            fromOpacity={0.8}
            toOpacity={0.2}
          />
        ))}
        {renderArea()}
        {hideGrid ? null : (
          <>
            {domainSizes.left > 0 && (
              <AxisLeft
                left={yAxisWidths.left + 4}
                scale={scales.left}
                hideTicks
                hideAxisLine
                tickClassName={commonStyles.tickLabel}
                tickLabelProps={() => ({ dx: 4, dy: 10, textAnchor: 'end' })}
                tickValues={scales.left.ticks(NumYTicks)}
                tickFormat={(n) => formatNumberToCompact({ num: n.valueOf() })}
              />
            )}
            {domainSizes.right > 0 && (
              <AxisRight
                left={width - yAxisWidths.right - 4}
                scale={scales.right}
                hideTicks
                hideAxisLine
                tickClassName={commonStyles.tickLabel}
                tickLabelProps={() => ({ dx: 7, dy: 10, textAnchor: 'end' })}
                tickValues={rightAxisTicks}
                tickFormat={(n) => formatNumberToCompact({ num: n.valueOf() })}
              />
            )}
          </>
        )}
        {highlightedDate && (
          <>
            <Group>
              <Line
                from={{ x: dateScale(highlightedDate), y: 0 }}
                to={{ x: dateScale(highlightedDate), y: ChartHeight }}
                stroke={hoverLineColor}
                strokeWidth={1}
                pointerEvents="none"
                strokeDasharray="5,2"
              />
            </Group>
            {tooltipData?.points.map((point, i) => {
              // A series without a value at this date ("N/A") has no position on the chart;
              // drawing it would emit a circle with cy=NaN.
              if (!Number.isFinite(point.value)) {
                return null;
              }

              // In stacked mode the marker sits on the upper edge of the series' band, i.e. the
              // sum of this series and every later series on the same axis.
              const value = stacked
                ? tooltipData.points
                    .slice(i)
                    .filter(({ axis }) => axis === point.axis)
                    .reduce((sum, p) => sum + (Number.isFinite(p.value) ? p.value : 0), 0)
                : point.value;

              return (
                <circle
                  key={i}
                  cx={dateScale(point.date)}
                  cy={scales[point.axis](value)}
                  r={5}
                  fill={point.color}
                  stroke="white"
                  strokeOpacity={0.2}
                  strokeWidth={1}
                  pointerEvents="none"
                />
              );
            })}
          </>
        )}
        {/* Transparent overlay that captures pointer events for the whole plot area. */}
        <Bar
          width={width}
          height={ChartHeight}
          fill="transparent"
          onTouchStart={onMouseMove}
          onTouchMove={onMouseMove}
          onMouseMove={onMouseMove}
          onMouseLeave={onMouseLeave}
        />
      </svg>

      {tooltipData && (
        <Tooltip
          top={ChartHeight * 0.3}
          left={dateScale(tooltipData.date)}
          tooltipData={tooltipData}
        />
      )}
    </div>
  );
};

/** Everything the hover tooltip needs to describe one highlighted date. */
interface TooltipData {
  /** Highlighted date; formatted as the tooltip title and used for its X position. */
  date: Date;
  /** Timezone passed to the date formatter. */
  timezone: string;
  /** Aggregation period passed to the date formatter. */
  aggPeriod: TimeAggregationPeriod;
  /** One entry per series, in series order. */
  points: {
    label: string;
    // NOTE(review): typed as number, but the chart substitutes the string 'N/A' at runtime
    // when a series has no value for the highlighted date.
    value: number;
    date: Date;
    /** Fill colour of the series, shown as marker colour. */
    color: string;
    axis: 'left' | 'right';
  }[];
}

/**
 * Hover tooltip listing the value of every series at the highlighted date, preceded by a
 * total when more than one series is shown. Renders nothing when there are no points.
 *
 * @param props.tooltipData - Date, formatting context and per-series points to display.
 * @param props.left - Horizontal position passed through to `ChartTooltip`.
 * @param props.top - Vertical position passed through to `ChartTooltip`.
 */
const Tooltip = (props: { tooltipData: TooltipData; left: number; top: number }) => {
  const { tooltipData, left, top } = props;

  if (tooltipData.points.length === 0) {
    return null;
  }

  // Points without a value carry the placeholder 'N/A' at runtime. Adding that to the running
  // sum would silently switch to string concatenation, so only finite numbers are summed.
  const total = tooltipData.points.reduce(
    (sum, { value }) => (Number.isFinite(value) ? sum + value : sum),
    0,
  );

  return (
    <ChartTooltip
      top={top}
      left={left}
      title={getFormattedDateForPeriod(
        tooltipData.date,
        tooltipData.aggPeriod,
        tooltipData.timezone,
      )}>
      <ul>
        {tooltipData.points.length > 1 && (
          <li>
            Total:{' '}
            <strong>
              {total.toLocaleString(undefined, {
                maximumFractionDigits: 2,
              })}
            </strong>
          </li>
        )}
        {tooltipData.points.map((point, i) => (
          <li key={i}>
            <span className={commonStyles.seriesMarker} style={{ background: point.color }} />
            {point.label ?? '-'}:{' '}
            <strong>
              {point.value.toLocaleString(undefined, {
                maximumFractionDigits: 2,
              })}
            </strong>
          </li>
        ))}
      </ul>
    </ChartTooltip>
  );
};

/** Props of the exported, optionally grouped, time-series chart. */
interface GroupedTimeSeriesChartProps {
  /** Either a single chart's data or a (possibly nested) list of groups containing chart data. */
  data: GroupedChartData<TimeSeriesData>;
  /** Active groupings; only their count is used here, to reserve horizontal space. */
  grouping: Grouping[];
  /** Aggregation period; selects tick values, tick format and the tooltip date format. */
  aggPeriod: TimeAggregationPeriod;
  /** Whether series are drawn stacked. Defaults to `false`. */
  stacked?: boolean;
  /** Called with the toggled value when the "Stacked" checkbox changes. */
  setStacked?: (stacked: boolean) => void;
  /** Hide the grid, the value axes and the bottom date axis. Defaults to `false`. */
  hideGrid?: boolean;
  /** Timezone passed to the tooltip's date formatter. */
  timezone: string;
}

/**
 * Collects the date of every data point in a (possibly nested) grouped chart structure.
 *
 * @param data - A single chart entry or a list of groups whose `items` are walked recursively.
 * @returns All item dates in traversal order; duplicates are kept.
 */
const getDates = (data: GroupedChartData<TimeSeriesData>): Date[] => {
  // Leaf: a single chart, take the dates of its items directly.
  if (!isArray(data)) {
    return data.chartData.items.map(({ dateValue }) => dateValue);
  }

  // Branch: descend into every group and concatenate the results.
  return data.flatMap((group) => getDates(group.items));
};

/**
 * Time-series area chart with optional grouping.
 *
 * Builds the shared date, value and colour scales from the complete data set, renders one
 * `TimeSeriesChart` per group through `GroupedChart`, and adds the bottom date axis, the
 * legend and (when at least one axis has several series) a "Stacked" toggle. The hovered
 * date is kept in local state so that all group charts highlight the same date.
 */
export const GroupedTimeSeriesChart = (props: GroupedTimeSeriesChartProps) => {
  const {
    data,
    grouping,
    aggPeriod,
    stacked = false,
    setStacked,
    hideGrid = false,
    timezone,
  } = props;

  const [highlightedDate, setHighlightedDate] = useState<Date | null>(null);
  const seriesDomain = getValueKeys(data);

  const colors: ChartColor[] = [
    { fill: barColor1, highlight: barColor1Accent },
    { fill: barColor2, highlight: barColor2Accent },
    { fill: barColor3, highlight: barColor3Accent },
    { fill: barColor4, highlight: barColor4Accent },
    { fill: barColor5, highlight: barColor5Accent },
    { fill: barColor6, highlight: barColor6Accent },
  ];
  const colorScale = scaleOrdinal<string, ChartColor>({
    domain: seriesDomain,
    // A lone series is drawn in the second palette colour.
    range: seriesDomain.length === 1 ? [colors[1]] : colors,
  });

  const legend = getSeries(data).map(({ key, label }) => ({
    label,
    color: colorScale(key).highlight,
  }));

  const isStackable =
    getValueKeys(data, 'left').length > 1 || getValueKeys(data, 'right').length > 1;

  const groupCount = grouping.length;

  // Distinct dates across all groups, ascending. De-duplication has to happen on timestamps:
  // `Date` objects are compared by reference, so equal dates from different groups would all
  // be kept. Sorting guarantees that the first/last entries really are the domain bounds.
  const dates = useMemo(
    () =>
      uniq(getDates(data).map((date) => date.getTime()))
        .sort((a, b) => a - b)
        .map((time) => new Date(time)),
    [data],
  );

  return (
    <ParentSize>
      {(parent) => {
        const dateScale = scaleTime({
          domain: [dates[0], dates[dates.length - 1]],
          range: [
            0,
            parent.width - (groupCount > 1 ? WidthPerGroup * (groupCount - SeriesDepth) : 0),
          ],
        });

        const leftMin = Math.min(0, getMinValue(data, 'left'));
        const leftMax = getMaxValue(data, 'left');
        const rightMin = Math.min(0, getMinValue(data, 'right'));
        const rightMax = getMaxValue(data, 'right');
        const valueScales = {
          left: scaleLinear<number>({
            round: true,
            nice: true,
            domain: [
              leftMin + leftMin * YValuePaddingCoef,
              leftMax + leftMax * YValuePaddingCoef,
            ],
            range: [ChartHeight, 0],
          }),
          right: scaleLinear<number>({
            round: true,
            nice: true,
            domain: [
              rightMin + rightMin * YValuePaddingCoef,
              rightMax + rightMax * YValuePaddingCoef,
            ],
            range: [ChartHeight, 0],
          }),
        };

        const xTickSize = parent.width / dates.length;

        const handleMouseMove = (
          event: React.TouchEvent<SVGRectElement> | React.MouseEvent<SVGRectElement>,
        ) => {
          const { x } = localPoint(event) || { x: 0 };
          const hoveredDate = new Date(dateScale.invert(x - xTickSize / 2));
          const closestDate = findClosestDate(dates, hoveredDate) ?? last(dates);

          // Nothing to highlight when the chart has no data points at all.
          if (closestDate == null) {
            return;
          }

          const nextDate = new Date(closestDate);
          // Keep the previous object when the date did not change so React can skip the
          // re-render that every single mouse move would otherwise trigger.
          setHighlightedDate((current) =>
            current !== null && current.getTime() === nextDate.getTime() ? current : nextDate,
          );
        };

        return (
          <div>
            <GroupedChart
              width={parent.width}
              widthPerGroup={WidthPerGroup}
              data={data}
              renderChart={(chartData, chartWidth) => {
                return (
                  <TimeSeriesChart
                    dateScale={dateScale}
                    valueScales={valueScales}
                    colorScale={colorScale}
                    aggPeriod={aggPeriod}
                    highlightedDate={highlightedDate}
                    width={chartWidth}
                    height={ChartHeight}
                    hideGrid={hideGrid}
                    data={chartData}
                    stacked={stacked}
                    onMouseMove={handleMouseMove}
                    onMouseLeave={() => setHighlightedDate(null)}
                    timezone={timezone}
                  />
                );
              }}
            />
            {hideGrid ? null : (
              <svg width={parent.width} height={25}>
                <AxisBottom
                  left={Math.max(0, WidthPerGroup * (groupCount - SeriesDepth))}
                  top={0}
                  scale={dateScale}
                  hideTicks
                  hideAxisLine
                  tickClassName={commonStyles.tickLabel}
                  tickLabelProps={() => ({
                    textAnchor: 'start',
                    dy: -4,
                  })}
                  tickValues={getTickValues(aggPeriod, dateScale)}
                  tickFormat={getTickFormatFn(aggPeriod)}
                />
              </svg>
            )}

            <div
              style={{
                marginLeft: Math.max(
                  HorizontalPadding,
                  WidthPerGroup * (groupCount - SeriesDepth),
                ),
                marginRight: HorizontalPadding,
                marginTop: hideGrid ? VerticalPadding : 0,
              }}>
              <div className={commonStyles.chartFooter}>
                <ColoredLegend items={legend} />
                {isStackable ? (
                  <CheckboxWithLabel
                    className={[commonStyles.stackingCheckbox]}
                    checked={stacked}
                    onChange={() => setStacked?.(!stacked)}>
                    Stacked
                  </CheckboxWithLabel>
                ) : null}
              </div>
            </div>
          </div>
        );
      }}
    </ParentSize>
  );
};
