import React, { useMemo } from 'react';
import { max } from 'lodash';

import { ScaleLinear, ScaleTime } from 'd3-scale';

import { Bar, Line } from '@visx/shape';
import { Grid, GridColumns } from '@visx/grid';
import { AxisLeft, AxisRight } from '@visx/axis';
import { Group } from '@visx/group';
import { getStringWidth } from '@visx/text';

import { formatNumberToCompact } from '@/lib/utils/number';
import { getFormattedDateForPeriod } from '@/lib/date';
import { ChartType, TimeAggregationPeriod } from '@/explore/types';

import { ChartTooltip } from '../common';
import { getHighlightedTickValues, getTickValues } from './utils';
import { TimeSeriesData } from '../grouped-chart/types';

import { AreaStack } from './area-stack';
import { LineStack } from './line-stack';
import { BarStack } from './bar-stack';

import { StackValues } from './stack-values';

import commonStyles, {
  gridLineColor,
  gridHighlightedLineColor,
  hoverLineColor,
} from '../charts.module.scss';

/** Tick count requested from the value scales for the y-axis labels and the horizontal grid rows. */
const NumYTicks = 7;

/**
 * Rows of stacked chart data: every row has a `date`, and any other key holds a `number | Date`.
 */
export type StackedChartData = Array<{
  date: Date;
  [key: string]: number | Date;
}>;

/** Pair of colours for a chart series. */
export interface ChartColor {
  /** Base fill colour. */
  fill: string;
  /** Alternate colour — presumably for a highlighted state; not used in this file. */
  highlight: string;
}

interface StackedTimeSeriesChartProps {
  // --- layout / display ---
  /** Width of the rendered SVG. */
  width: number;
  /** Height of the rendered SVG. */
  height: number;
  /** When true, neither the grid nor the y-axis labels are drawn. */
  hideGrid: boolean;

  // --- data ---
  /** Series definitions plus the per-date items that hold their values. */
  data: TimeSeriesData;
  /** Aggregation period; selects grid column ticks and the tooltip title format. */
  aggPeriod: TimeAggregationPeriod;
  /** Timezone handed to the tooltip title formatter. */
  timezone: string;

  // --- scales ---
  /** Maps dates to x positions. */
  dateScale: ScaleTime<number, number>;
  /** Maps values to y positions for series on the left and on the right axis. */
  valueScales: {
    left: ScaleLinear<number, number>;
    right: ScaleLinear<number, number>;
  };

  // --- interaction ---
  /** Date marked with the dashed hover line, hover dots and tooltip; null for none. */
  highlightedDate: Date | null;
  /** Attached to mouse-move, touch-start and touch-move of the transparent overlay rect. */
  onMouseMove?: (
    event: React.MouseEvent<SVGRectElement> | React.TouchEvent<SVGRectElement>,
  ) => void;
  /** Attached to mouse-leave of the transparent overlay rect. */
  onMouseLeave?: () => void;
}

/**
 * Time-series chart that stacks area, line and bar series separately for the left and right
 * y axes.
 *
 * It renders, in order:
 * - an optional grid and y-axis labels (hidden by `hideGrid`);
 * - a dashed vertical line, hover dots and a tooltip for `highlightedDate`;
 * - a transparent overlay rect that receives the pointer handlers.
 *
 * An axis whose value domain spans zero or less is not drawn. The primary axis is the left one,
 * unless it is empty, in which case the right one is used. The horizontal grid follows the
 * primary axis' ticks, and the right-axis labels are placed at those same pixel positions so both
 * axes line up with the grid rows.
 */
export const StackedTimeSeriesChart = (props: StackedTimeSeriesChartProps) => {
  const {
    width,
    height,
    data,
    dateScale,
    valueScales,
    aggPeriod,
    highlightedDate,
    onMouseMove,
    onMouseLeave,
    hideGrid,
    timezone,
  } = props;

  // Tooltip content for the highlighted date. It is null when nothing is highlighted or when no
  // data item has exactly that date.
  const tooltipData: TooltipData | null = useMemo(() => {
    if (highlightedDate === null) {
      return null;
    }

    const matchingItem = data.items.find(
      ({ dateValue }) => dateValue.getTime() === highlightedDate.getTime(),
    );

    if (matchingItem === undefined) {
      return null;
    }

    // Only series with a finite numeric value at this date become points. Substituting a
    // string placeholder here (previously `?? 'N/A'`) violated the `value: number` contract and
    // turned the hover-dot stack sums and the tooltip total into string concatenation / NaN.
    const points: TooltipData['points'] = [];
    for (const { key, label, axis, chartType, color } of data.series) {
      const value = matchingItem.values[key];
      if (typeof value === 'number' && Number.isFinite(value)) {
        points.push({ label, value, date: matchingItem.dateValue, color, axis, chartType });
      }
    }

    return {
      date: highlightedDate,
      timezone,
      aggPeriod,
      points,
    };
  }, [highlightedDate, data, timezone, aggPeriod]);

  // Widest formatted tick label per axis; used to inset the axis labels from the chart edges.
  const yAxisWidths = useMemo(
    () =>
      Object.entries(valueScales).reduce(
        (acc, [key, valueScale]) => {
          return {
            ...acc,
            [key]:
              max(
                valueScale
                  .ticks(NumYTicks)
                  .map((i) => formatNumberToCompact({ num: i.valueOf() }))
                  .map((i) => getStringWidth(i)),
              ) ?? 0,
          };
        },
        { left: 0, right: 0 },
      ),
    [valueScales],
  );

  // An axis with a non-positive domain span has nothing to draw.
  const domainSizes = {
    left: valueScales.left.domain()[1] - valueScales.left.domain()[0],
    right: valueScales.right.domain()[1] - valueScales.right.domain()[0],
  };

  // Default to right as main axis in case no series on left axis
  const primaryValueScale = domainSizes.left > 0 ? valueScales.left : valueScales.right;

  // Force right axis ticks to align with primary axis ticks
  const rightAxisTicks = primaryValueScale
    .ticks(NumYTicks)
    .map((tick) => valueScales.right.invert(primaryValueScale(tick)));

  // Points that get hover dots; empty when there is no tooltip.
  const hoverPoints = tooltipData?.points ?? [];

  return (
    <div className={commonStyles.graph}>
      <svg width={width} height={height}>
        {hideGrid ? null : (
          <>
            <Grid
              width={width}
              height={height}
              xScale={dateScale}
              yScale={primaryValueScale}
              rowTickValues={primaryValueScale.ticks(NumYTicks)}
              stroke={gridLineColor}
              columnTickValues={getTickValues(aggPeriod, dateScale)}
            />
            <GridColumns
              width={width}
              height={height}
              scale={dateScale}
              stroke={gridHighlightedLineColor}
              tickValues={getHighlightedTickValues(aggPeriod, dateScale)}
            />
          </>
        )}

        {(['left', 'right'] as const).map((axis) => {
          const areaKeys = data.series
            .filter((s) => s.chartType === 'area' && s.axis === axis)
            .map(({ key }) => key);
          const lineKeys = data.series
            .filter((s) => s.chartType === 'line' && s.axis === axis)
            .map(({ key }) => key);
          const barKeys = data.series
            .filter((s) => s.chartType === 'bar' && s.axis === axis)
            .map(({ key }) => key);

          return domainSizes[axis] <= 0 ? null : (
            <React.Fragment key={axis}>
              <AreaStack
                keys={areaKeys}
                data={data}
                valueScale={valueScales[axis]}
                dateScale={dateScale}
              />

              <LineStack
                keys={lineKeys}
                data={data}
                valueScale={valueScales[axis]}
                dateScale={dateScale}
              />

              <BarStack
                keys={barKeys}
                data={data}
                valueScale={valueScales[axis]}
                dateScale={dateScale}
              />

              <StackValues
                keys={areaKeys}
                data={data}
                valueScale={valueScales[axis]}
                dateScale={dateScale}
                reverse
                avoidOverflow
              />

              <StackValues
                keys={lineKeys}
                data={data}
                valueScale={valueScales[axis]}
                dateScale={dateScale}
                reverse
                avoidOverflow
              />

              <StackValues
                keys={barKeys}
                data={data}
                valueScale={valueScales[axis]}
                dateScale={dateScale}
              />
            </React.Fragment>
          );
        })}
        {hideGrid ? null : (
          <>
            {domainSizes.left > 0 && (
              <AxisLeft
                left={yAxisWidths.left + 4}
                scale={valueScales.left}
                hideTicks
                hideAxisLine
                tickClassName={commonStyles.tickLabel}
                tickLabelProps={() => ({ dx: 4, dy: 10, textAnchor: 'end' })}
                tickValues={valueScales.left.ticks(NumYTicks)}
                tickFormat={(n) => formatNumberToCompact({ num: n.valueOf() })}
              />
            )}
            {domainSizes.right > 0 && (
              <AxisRight
                left={width - yAxisWidths.right - 4}
                scale={valueScales.right}
                hideTicks
                hideAxisLine
                tickClassName={commonStyles.tickLabel}
                tickLabelProps={() => ({ dx: 7, dy: 10, textAnchor: 'end' })}
                tickValues={rightAxisTicks}
                tickFormat={(n) => formatNumberToCompact({ num: n.valueOf() })}
              />
            )}
          </>
        )}
        {highlightedDate && (
          <>
            <Group>
              <Line
                from={{ x: dateScale(highlightedDate), y: 0 }}
                to={{ x: dateScale(highlightedDate), y: height }}
                stroke={hoverLineColor}
                strokeWidth={1}
                pointerEvents="none"
                strokeDasharray="5,2"
              />
            </Group>
            {hoverPoints.map((point, i) => {
              // Bar series get no hover dot.
              if (point.chartType === 'bar') {
                return null;
              }

              // The dot sits at the cumulative value of this point plus every later point that
              // shares its axis and chart type.
              const stackedValue = hoverPoints
                .slice(i)
                .filter((p) => p.axis === point.axis && p.chartType === point.chartType)
                .reduce((sum, p) => sum + p.value, 0);

              return (
                <circle
                  key={i}
                  cx={dateScale(point.date)}
                  cy={valueScales[point.axis](stackedValue)}
                  r={5}
                  fill={point.color}
                  stroke="white"
                  strokeOpacity={0.2}
                  strokeWidth={1}
                  pointerEvents="none"
                />
              );
            })}
          </>
        )}
        <Bar
          width={width}
          height={height}
          fill="transparent"
          onTouchStart={onMouseMove}
          onTouchMove={onMouseMove}
          onMouseMove={onMouseMove}
          onMouseLeave={onMouseLeave}
        />
      </svg>

      {tooltipData && (
        <Tooltip top={height * 0.3} left={dateScale(tooltipData.date)} tooltipData={tooltipData} />
      )}
    </div>
  );
};

/** Everything the hover tooltip needs for one highlighted date. */
interface TooltipData {
  /** Highlighted date; formatted into the tooltip title. */
  date: Date;
  /** Timezone handed to the title's date formatter. */
  timezone: string;
  /** Aggregation period handed to the title's date formatter. */
  aggPeriod: TimeAggregationPeriod;
  /** Per-series entries at `date`, in the order they are listed in the tooltip. */
  points: Array<{
    label: string;
    value: number;
    date: Date;
    color: string;
    axis: 'left' | 'right';
    chartType: ChartType;
  }>;
}

/**
 * Hover tooltip: the formatted date as title, one row per point and, when there is more than one
 * point, a leading "Total" row summing every point's value. Renders nothing for zero points.
 */
const Tooltip = (props: { tooltipData: TooltipData; left: number; top: number }) => {
  const { top, left } = props;
  const { date, aggPeriod, timezone, points } = props.tooltipData;

  if (!points.length) {
    return null;
  }

  // Shared number formatting for the total and the individual rows.
  const formatValue = (value: number) =>
    value.toLocaleString(undefined, {
      maximumFractionDigits: 2,
    });

  const title = getFormattedDateForPeriod(date, aggPeriod, timezone);
  const total = points.reduce((sum, point) => sum + point.value, 0);
  const showTotal = points.length > 1;

  return (
    <ChartTooltip top={top} left={left} title={title}>
      <ul>
        {showTotal && (
          <li>
            Total:{' '}
            <strong>{formatValue(total)}</strong>
          </li>
        )}
        {points.map((point, index) => (
          <li key={index}>
            <span className={commonStyles.seriesMarker} style={{ background: point.color }} />
            {point.label ?? '-'}:{' '}
            <strong>{formatValue(point.value)}</strong>
          </li>
        ))}
      </ul>
    </ChartTooltip>
  );
};
