import { compact, uniq } from 'lodash';
import { translateToPrql } from '@gosupersimple/penguino';
import { common } from '@gosupersimple/types';

import { unlessNil } from '@/lib/utils';

import {
  AggregationType,
  Fields,
  DereferencedPipelineOperation,
  Model,
  PipelineState,
  QueryVariables,
  CompositeFilterCondition,
  TimeAggregationPeriod,
  DereferencedPipeline,
  Pipeline,
  Exploration,
  Aggregation,
  Metric,
  Field,
  Relation,
  PipelineStateRelation,
} from '../types';
import {
  isExprAggregation,
  isKeyedAggregation,
  isMetricAggregation,
  isValueExpression,
} from './operation';
import { FieldGroup, dereferencePipeline, ensureUniqueFieldNames } from './utils';
import { getModelOrThrow, getModel, getJoinKeyOnBase } from '../model/utils';
import { convertFieldsForPenguinoContext } from '../utils/penguino';
import { getVariableExpression, isNumberType, isStringType } from '../utils';
import { findModelPropertyField } from '../edit-pipeline/utils/relation';

/**
 * Shared lookup data threaded through every pipeline-state computation in this module.
 */
export interface PipelineStateContext {
  /** Models that base models and relation targets are looked up in. */
  models: Model[];
  /** Query variables handed to expression translation and precision resolution. */
  variables: QueryVariables;
  /** Known metrics; metric aggregations are checked against this list. */
  metrics: Metric[];
}

/**
 * Error thrown by the validation and state-building helpers in this module when an
 * operation cannot be applied (for example a referenced field, relation, model or
 * metric is missing, or an expression fails to translate).
 *
 * `getNextPipelineState` treats this error type as recoverable and rethrows anything else.
 */
export class InvalidOperationError extends Error {
  constructor(message: string) {
    super(message);
    // Replace the inherited `name` so the error is labelled with this class's name.
    this.name = 'InvalidOperationError';
  }
}

/**
 * Placeholder model with no primary key, properties or relations.
 * `getBaseState` returns it as the state's model when the requested base model is not found.
 */
export const InvalidModel = {
  modelId: 'invalid',
  name: 'Invalid model',
  primaryKey: [],
  properties: [],
  relations: [],
  labels: {},
};

/**
 * Works out which model a field key points at.
 *
 * - If `key` is part of the model's primary key, the field points at the model itself.
 * - Otherwise the first `hasOne` relation whose join strategy has a `joinKeyOnBase` equal
 *   to `key` is looked up; when that strategy also exposes `joinKeyOnRelated`, the field
 *   points at the related model through that key.
 *
 * When no relation matches, the result is whatever `unlessNil` yields for a nil value.
 */
const constructFieldRelation = (model: Model, key: string) => {
  if (model.primaryKey.includes(key)) {
    return { modelId: model.modelId, name: model.name, key };
  }

  const owningRelation = model.relations.find((candidate) => {
    if (candidate.type !== 'hasOne') {
      return false;
    }
    const strategy = candidate.joinStrategy;
    return strategy && 'joinKeyOnBase' in strategy && key === strategy.joinKeyOnBase;
  });

  return unlessNil(owningRelation, (relation) => {
    const strategy = relation.joinStrategy;
    if (strategy && 'joinKeyOnRelated' in strategy) {
      return {
        modelId: relation.modelId,
        name: relation.name,
        key: strategy.joinKeyOnRelated,
      };
    }
    return undefined;
  });
};

/**
 * Builds a pipeline field descriptor.
 *
 * @param key Field key within the pipeline state.
 * @param name Display name of the field.
 * @param type Field type.
 * @param model Model the field originates from, if any. Without it, `pk` is `false` and
 *   both `model` and `relation` are `undefined`.
 * @param keyInRelatedModel Property key inside `model`; falls back to `key` when omitted.
 * @param precision Time aggregation precision to attach to the field, if any.
 */
const constructField = (
  key: string,
  name: string,
  type: string,
  model?: Model,
  keyInRelatedModel?: string,
  precision?: TimeAggregationPeriod,
) => {
  const propertyKey = keyInRelatedModel ?? key;
  // The primary-key flag is decided by the pipeline key, not by the property key.
  const pk = model?.primaryKey.includes(key) ?? false;
  const modelReference =
    model === undefined ? undefined : { modelId: model.modelId, name: model.name, propertyKey };
  const relation = model === undefined ? undefined : constructFieldRelation(model, propertyKey);

  return { key, name, type, pk, model: modelReference, relation, precision };
};

/**
 * Applies a single pipeline operation to `state` and returns the resulting model/fields.
 *
 * Disabled operations and unrecognised operation kinds leave the state untouched.
 * The returned value carries no `relations`; callers in this module recompute them
 * from the new fields.
 *
 * @param state State before the operation.
 * @param operation Operation to apply.
 * @param operationIndex Position of the operation in its pipeline; used to build the
 *   renamed keys of colliding right-hand fields in `joinPipeline`.
 * @param ctx Models, variables and metrics used for lookups and validation.
 * @throws InvalidOperationError when the operation references a missing field, relation,
 *   model or metric, or when an expression or slice is invalid. Helpers such as
 *   `getModelOrThrow` may throw other errors (not visible from this file).
 */
export const getNextPipelineStateOrThrow = (
  state: PipelineState,
  operation: DereferencedPipelineOperation,
  operationIndex: number,
  ctx: PipelineStateContext,
): Omit<PipelineState, 'relations'> => {
  const { models, variables, metrics } = ctx;
  if (operation.disabled === true) {
    return state;
  }
  switch (operation.operation) {
    // Filtering never changes the shape of the state; it is only validated.
    case 'filter': {
      validateFilterCondition(operation.parameters, [{ fields: state.fields }], variables);
      return state;
    }
    // Appends one computed field; its type comes from the translated expression and
    // falls back to 'String' when the translation reports no return type.
    case 'deriveField': {
      const { returnType, format } = getExpressionResult(
        operation.parameters.value.expression,
        [{ fields: state.fields }],
        variables,
      );

      const fields = [
        ...state.fields,
        {
          key: operation.parameters.key,
          name: operation.parameters.fieldName,
          type: returnType ?? 'String',
          format: format,
        },
      ];

      return { ...state, fields };
    }

    // Appends one field per aggregation computed over a related pipeline.
    case 'relationAggregate': {
      // When only a relation is given (no explicit pipeline), aggregate over an empty
      // pipeline based on the model that relation points to.
      const relatedPipeline =
        'relation' in operation.parameters
          ? operation.parameters.pipeline === undefined
            ? {
                baseModelId: findValidRelation(
                  state.relations,
                  models,
                  operation.parameters.relation.key,
                  operation.parameters.relation.modelId,
                ).modelId,
                operations: [],
              }
            : operation.parameters.pipeline
          : operation.parameters.pipeline;

      // Uses the non-throwing `getFinalState`, so invalid operations inside the related
      // pipeline are skipped rather than failing this operation.
      const relatedPipelineFields = getFinalState(
        relatedPipeline.baseModelId,
        relatedPipeline.operations,
        ctx,
      ).fields;

      const relationKey =
        'relation' in operation.parameters ? operation.parameters.relation.key : undefined;

      validateSlice(operation.parameters.slice, relatedPipelineFields);
      // Filters may reference fields of the current state (un-namespaced and under the
      // current model id) as well as fields of the related pipeline.
      operation.parameters.filters?.forEach((filter) =>
        validateFilterCondition(
          filter.parameters,
          [
            { fields: state.fields },
            { key: state.model.modelId, fields: state.fields }, // TODO: figure out 'current table' namespace
            { key: relationKey ?? relatedPipeline.baseModelId, fields: relatedPipelineFields },
          ],
          variables,
        ),
      );

      const fields = state.fields.concat(
        operation.parameters.aggregations.map((aggregation) => {
          const relatedField = relatedPipelineFields.find((field) => field.key === aggregation.key);
          if (isKeyedAggregation(aggregation) && relatedField === undefined) {
            throw new InvalidOperationError(`Field "${aggregation.key}" not found`);
          }
          // Only 'first' aggregations keep a link to the source model and property;
          // every other aggregation type produces a field without model information.
          return constructField(
            aggregation.property.key,
            aggregation.property.name,
            getResultTypeForAggregation(
              aggregation,
              [{ fields: relatedPipelineFields }],
              variables,
              metrics,
            ),
            aggregation.type === 'first'
              ? getModel(models, relatedField?.model?.modelId ?? relatedPipeline.baseModelId)
              : undefined,
            aggregation.type === 'first'
              ? (relatedField?.model?.propertyKey ?? aggregation.key)
              : undefined,
          );
        }),
      );

      return { ...state, fields };
    }

    // Appends the requested columns from a related model (or a pipeline based on it).
    case 'addRelatedColumn': {
      const relatedModel = getRelatedModelOrThrow(
        state,
        models,
        operation.parameters.relation.key,
        operation.parameters.relation.modelId,
      );
      // Without an explicit pipeline the related model's own fields are used.
      const relatedPipelineFields = getFinalState(
        operation.parameters.pipeline?.baseModelId ?? relatedModel.modelId,
        operation.parameters.pipeline?.operations ?? [],
        ctx,
      ).fields;
      const newFields = operation.parameters.columns.map((column) => {
        const relatedField = relatedPipelineFields.find((field) => field.key === column.key);
        if (relatedField === undefined) {
          throw new InvalidOperationError(`Field "${column.key}" not found`);
        }
        // The new field keeps the source field's model/property reference when it has
        // one, otherwise it is attributed to the related model under the column key.
        return constructField(
          column.property.key,
          column.property.name,
          relatedField.type,
          getModel(models, relatedField?.model?.modelId ?? relatedModel.modelId),
          relatedField?.model?.propertyKey ?? column.key,
        );
      });

      const fields = state.fields.concat(newFields);

      return {
        ...state,
        fields,
      };
    }

    // Replaces the field list with the grouping fields followed by the aggregations.
    case 'groupAggregate': {
      validateSlice(operation.parameters.slice, state.fields);
      validateGroupDependencies(operation.parameters.groups, state.fields);

      // With no aggregations the state is returned unchanged.
      if (operation.parameters.aggregations.length === 0) {
        return state;
      }

      const fields = [
        ...operation.parameters.groups.map(({ key, precision }) => {
          const field = state.fields.find((f) => f.key === key);
          if (field === undefined) {
            throw new InvalidOperationError(`Field "${key}" not found`);
          }

          // Grouping by 'day_of_week' yields an integer regardless of the source type.
          // NOTE(review): `keyInRelatedModel` is passed as undefined, so the resulting
          // `model.propertyKey` is the group key even when the source field's
          // `model.propertyKey` differs — confirm this is intended.
          return constructField(
            key,
            findFieldName(state.fields, key),
            precision === 'day_of_week' ? 'Integer' : field.type,
            getModel(models, field.model?.modelId ?? ''),
            undefined,
            precisionToTimeInterval(precision, variables),
          );
        }),
        ...operation.parameters.aggregations.map((aggregation) => {
          return {
            key: aggregation.property.key,
            name: aggregation.property.name,
            type: getResultTypeForAggregation(
              aggregation,
              [{ fields: state.fields }],
              variables,
              metrics,
            ),
          };
        }),
      ];

      return {
        ...state,
        fields,
      };
    }

    // Replaces both the model and the fields with those of the related model; nothing
    // from the previous state is carried over.
    case 'switchToRelation': {
      const relatedModel = getRelatedModelOrThrow(
        state,
        models,
        operation.parameters.relation.key,
        operation.parameters.relation.modelId,
      );
      return {
        model: relatedModel,
        fields: relatedModel.properties.map(({ key, name, type }) =>
          constructField(key, name, type, relatedModel),
        ),
      };
    }

    // Appends the fields of the joined (right-hand) pipeline to the current fields.
    case 'joinPipeline': {
      const joinStrategy = operation.parameters.joinStrategy;
      const baseModelId = operation.parameters.pipeline.baseModelId;
      const rightModel = getModel(models, baseModelId);
      if (rightModel === undefined) {
        throw new InvalidOperationError(`Model "${baseModelId}" not found`);
      }
      const rightModelName = rightModel.name;
      const pipelineState = getFinalState(
        baseModelId,
        operation.parameters.pipeline.operations,
        ctx,
      );
      const leftFieldKeys = state.fields.map((f) => f.key);
      // When both sides join on the same key name, the right-hand join key is dropped
      // from the appended fields.
      const rightSideFields = pipelineState.fields.filter(
        (field) =>
          !(
            joinStrategy.joinKeyOnBase === joinStrategy.joinKeyOnRelated &&
            field.key === joinStrategy.joinKeyOnRelated
          ),
      );
      if (!leftFieldKeys.includes(joinStrategy.joinKeyOnBase)) {
        throw new InvalidOperationError(`Field "${joinStrategy.joinKeyOnBase}" not found`);
      }
      // The related join key is checked against the unfiltered right-hand fields.
      if (!pipelineState.fields.some((field) => field.key === joinStrategy.joinKeyOnRelated)) {
        throw new InvalidOperationError(`Field "${joinStrategy.joinKeyOnRelated}" not found`);
      }
      // This must mimic the backend's field key & name renaming convention exactly:
      // right-hand keys that collide with a left-hand key get a `join_<index>_r_` prefix.
      const rightFieldMap = new Map(
        rightSideFields.map((f) => [
          f.key,
          leftFieldKeys.includes(f.key) ? `join_${operationIndex}_r_${f.key}` : f.key,
        ]),
      );

      const fields = state.fields.concat(
        ensureUniqueFieldNames(
          state.fields,
          rightSideFields.map((field) => ({
            ...field,
            key: rightFieldMap.get(field.key) ?? field.key,
          })),
          rightModelName,
        ),
      );

      return {
        ...state,
        fields,
      };
    }

    // Replaces the model with an ad-hoc one built from the operation's declared fields;
    // fields without a declared type default to 'String'.
    case 'sql': {
      const sqlModel: Model = {
        modelId: 'sql-model',
        name: 'Ad-Hoc model from SQL',
        primaryKey: [],
        properties: (operation.parameters.fields ?? []).map(({ key, type }) => ({
          key,
          name: key,
          type: type ?? 'String',
        })),
        relations: [],
        labels: {},
      };

      return { ...state, model: sqlModel, fields: sqlModel.properties };
    }

    // Any other operation kind does not affect the model or fields.
    default: {
      return state;
    }
  }
};

/**
 * Lenient variant of `getNextPipelineStateOrThrow`.
 *
 * If applying the operation fails with an `InvalidOperationError`, the incoming state is
 * returned unchanged (the operation is effectively skipped). Any other error propagates.
 */
export const getNextPipelineState = (
  state: PipelineState,
  operation: DereferencedPipelineOperation,
  operationIndex: number,
  ctx: PipelineStateContext,
): Omit<PipelineState, 'relations'> => {
  try {
    return getNextPipelineStateOrThrow(state, operation, operationIndex, ctx);
  } catch (error) {
    if (error instanceof InvalidOperationError) {
      return state;
    }
    throw error;
  }
};

/**
 * Returns the result type that is fixed by the aggregation type alone:
 * `'Integer'` for counts, `'Number'` for metrics, and `null` when the result type
 * cannot be decided from the aggregation type.
 */
const getFieldTypeByAggregationType = (type: AggregationType) => {
  if (type === 'count' || type === 'count_distinct') {
    return 'Integer';
  }
  return type === 'metric' ? 'Number' : null;
};

/**
 * Computes the pipeline state after applying the first `index` operations, starting from
 * the base model's state. Relations are recomputed from the fields after every step.
 *
 * @throws Whatever `getBaseStateOrThrow`, `getNextPipelineStateOrThrow` or the relation
 *   lookup throws; nothing is swallowed here.
 */
export const getPipelineStateAtIndexOrThrow = (
  baseModelId: string,
  operations: DereferencedPipelineOperation[],
  index: number,
  ctx: PipelineStateContext,
): PipelineState => {
  const precedingOperations = operations.slice(0, index);
  let current = getBaseStateOrThrow(baseModelId, ctx.models);

  precedingOperations.forEach((operation, position) => {
    const next = getNextPipelineStateOrThrow(current, operation, position, ctx);
    current = { ...next, relations: getAvailableRelations(next.fields, ctx.models) };
  });

  return current;
};

/**
 * Lenient counterpart of `getPipelineStateAtIndexOrThrow`: computes the state after the
 * first `index` operations using `getBaseState` and `getNextPipelineState`, so an
 * unknown base model or an operation failing with `InvalidOperationError` does not abort
 * the computation. Relations are recomputed from the fields after every step.
 */
export const getPipelineStateAtIndex = (
  baseModelId: string,
  operations: DereferencedPipelineOperation[],
  index: number,
  ctx: PipelineStateContext,
): PipelineState => {
  const precedingOperations = operations.slice(0, index);
  let current = getBaseState(baseModelId, ctx.models);

  precedingOperations.forEach((operation, position) => {
    const next = getNextPipelineState(current, operation, position, ctx);
    current = { ...next, relations: getAvailableRelations(next.fields, ctx.models) };
  });

  return current;
};

/**
 * Builds the initial pipeline state for a base model: one field per model property plus
 * the relations available from those fields.
 *
 * @throws When `getModelOrThrow` cannot resolve `baseModelId`.
 */
export const getBaseStateOrThrow = (baseModelId: string, models: Model[]): PipelineState => {
  const model = getModelOrThrow(models, baseModelId);
  const fields = model.properties.map((property) =>
    constructField(property.key, property.name, property.type, model),
  );
  const relations = getAvailableRelations(fields, models);

  return { model, relations, fields };
};

/**
 * Builds the initial pipeline state for a base model without throwing on an unknown
 * model: in that case the state holds `InvalidModel` with no fields and no relations.
 */
export const getBaseState = (baseModelId: string, models: Model[]): PipelineState => {
  const model = getModel(models, baseModelId);

  if (model !== undefined) {
    const fields = model.properties.map((property) =>
      constructField(property.key, property.name, property.type, model),
    );
    const relations = getAvailableRelations(fields, models);
    return { model, relations, fields };
  }

  return { model: InvalidModel, relations: [], fields: [] };
};

/**
 * State after applying every operation of the pipeline; propagates any error raised
 * along the way.
 */
export const getFinalStateOrThrow = (
  baseModelId: string,
  pipeline: DereferencedPipelineOperation[],
  ctx: PipelineStateContext,
): PipelineState => {
  const operationCount = pipeline.length;
  return getPipelineStateAtIndexOrThrow(baseModelId, pipeline, operationCount, ctx);
};

/**
 * State after applying every operation of the pipeline, using the lenient
 * `getPipelineStateAtIndex`.
 */
export const getFinalState = (
  baseModelId: string,
  pipeline: DereferencedPipelineOperation[],
  ctx: PipelineStateContext,
): PipelineState => {
  const operationCount = pipeline.length;
  return getPipelineStateAtIndex(baseModelId, pipeline, operationCount, ctx);
};

/**
 * Dereferences `pipeline` against `exploration` and returns the fields of its final
 * state. Errors from the strict state computation propagate to the caller.
 */
export const getPipelineFields = (
  pipeline: Pipeline,
  exploration: Exploration,
  ctx: PipelineStateContext,
) => {
  const dereferenced = dereferencePipeline(pipeline, exploration);
  return getDereferencedPipelineFields(dereferenced, ctx);
};

/**
 * Returns the fields of the final state of an already dereferenced pipeline, computed
 * with the throwing `getFinalStateOrThrow`.
 */
export const getDereferencedPipelineFields = (
  pipeline: DereferencedPipeline,
  ctx: PipelineStateContext,
) => {
  const { baseModelId, operations } = pipeline;
  return getFinalStateOrThrow(baseModelId, operations, ctx).fields;
};

/**
 * Lists the relations reachable from the given fields.
 *
 * For every distinct model referenced by a field, the model's relations are kept when
 * `joinKeyOnBaseExists` confirms their base join key among the fields; each kept
 * relation is tagged with the `baseModelId` of the model it belongs to.
 *
 * @throws InvalidOperationError when a field references a model missing from `models`.
 */
const getAvailableRelations = (fields: Field[], models: Model[]) => {
  const referencedModelIds = uniq(compact(fields.map((field) => field.model?.modelId)));

  return referencedModelIds.flatMap((modelId) => {
    const model = getModel(models, modelId);
    if (model === undefined) {
      const knownModelIds = models.map((m) => m.modelId).join(', ');
      throw new InvalidOperationError(`Model "${modelId}" not found in ${knownModelIds}`);
    }

    const joinableRelations =
      model.relations?.filter((relation) => joinKeyOnBaseExists(fields, relation, model)) ?? [];
    return joinableRelations.map((relation) => ({ ...relation, baseModelId: model.modelId }));
  });
};

/**
 * Finds the single relation with the given key among `relations`.
 *
 * @param relations Relations available in the current pipeline state.
 * @param models Models used to verify that the relation's target model exists.
 * @param relationKey Key of the relation to find.
 * @param baseModelId When provided, only relations belonging to this model are considered.
 * @returns The matching relation.
 * @throws InvalidOperationError when no relation matches, when more than one matches, or
 *   when the matched relation's target model is not in `models`.
 */
export const findValidRelation = (
  relations: PipelineStateRelation[],
  models: Model[],
  relationKey: string,
  baseModelId?: string,
) => {
  const candidates = relations.filter((candidate) => {
    if (candidate.key !== relationKey) {
      return false;
    }
    return candidate.baseModelId === baseModelId || baseModelId === undefined;
  });

  if (candidates.length === 0) {
    const scope = baseModelId !== undefined ? ` in model ${baseModelId}` : '';
    throw new InvalidOperationError(`Relation "${relationKey}" not found${scope}`);
  }
  if (candidates.length > 1) {
    throw new InvalidOperationError(
      `Relation key "${relationKey}" is ambiguous. Please provide a model ID.`,
    );
  }

  const relation = candidates[0];
  // The relation is only valid if the model it points to is actually known.
  if (getModel(models, relation.modelId) === undefined) {
    throw new InvalidOperationError(`Model "${relation.modelId}" not found`);
  }
  return relation;
};

/**
 * Tells whether the base-side join key of `relation` can be found among `fields` as a
 * property of `model` (as resolved by `findModelPropertyField`).
 */
const joinKeyOnBaseExists = (fields: Field[], relation: Relation, model: Model) =>
  findModelPropertyField(fields, model.modelId, getJoinKeyOnBase(relation, model)) !== undefined;

/**
 * Resolves the model targeted by the relation `relationKey` of the current state.
 *
 * @throws InvalidOperationError (via `findValidRelation`) when the relation is missing,
 *   ambiguous or points to an unknown model.
 */
const getRelatedModelOrThrow = (
  state: PipelineState,
  models: Model[],
  relationKey: string,
  modelId?: string,
) => {
  const { modelId: relatedModelId } = findValidRelation(
    state.relations,
    models,
    relationKey,
    modelId,
  );
  return getModelOrThrow(models, relatedModelId);
};

/**
 * Validates a (possibly nested) filter condition against the available fields.
 *
 * - `and` / `or` conditions are valid when all of their operands are valid.
 * - Conditions without a `key` are accepted as-is.
 * - Keyed conditions must reference a field present in one of the field groups; when the
 *   compared value is an expression, its return type must equal the field type or both
 *   must be number types or both string types.
 *
 * @returns `true` when the condition is valid.
 * @throws InvalidOperationError on a missing field, an untranslatable expression or a
 *   type mismatch.
 */
const validateFilterCondition = (
  condition: CompositeFilterCondition,
  fieldGroups: FieldGroup[],
  variables: QueryVariables,
): boolean => {
  if (condition.operator === 'and' || condition.operator === 'or') {
    return condition.operands.every((operand) =>
      validateFilterCondition(operand, fieldGroups, variables),
    );
  }
  if (!('key' in condition)) {
    return true;
  }

  const conditionKey = condition.key;
  const field = fieldGroups
    .flatMap(({ fields }) => fields)
    .find((candidate) => candidate.key === conditionKey);
  if (field === undefined) {
    throw new InvalidOperationError(`Field "${conditionKey}" not found`);
  }

  // Literal values are not type-checked here; only expressions are.
  if (!isValueExpression(condition.value)) {
    return true;
  }

  const { returnType } = getExpressionResult(condition.value.expression, fieldGroups, variables);
  const compatible =
    returnType === field.type ||
    (isNumberType(field.type) && isNumberType(returnType)) ||
    (isStringType(field.type) && isStringType(returnType));
  if (compatible) {
    return true;
  }
  throw new InvalidOperationError(`Expected type ${field.type}, got ${returnType}`);
};

/**
 * Validates a slice (limit / offset / sort) against the available fields.
 *
 * An undefined slice is always valid. Only negative `limit` / `offset` values are
 * rejected; zero is accepted. Sort keys must reference existing fields.
 *
 * The function communicates failure exclusively by throwing and returns nothing on
 * every path, like `validateGroupDependencies`.
 *
 * @throws InvalidOperationError on a negative limit or offset, or an unknown sort key.
 */
const validateSlice = (slice: common.Slice | undefined, fields: Fields): void => {
  if (slice === undefined) {
    return;
  }
  if (slice.limit !== undefined && slice.limit < 0) {
    throw new InvalidOperationError('Slice limit must be a positive number');
  }
  if (slice.offset !== undefined && slice.offset < 0) {
    throw new InvalidOperationError('Slice offset must be a positive number');
  }
  validateSorting(slice.sort, fields);
};

/**
 * Ensures every sort entry references an existing field. An undefined sorting is valid.
 *
 * The function communicates failure exclusively by throwing and returns nothing on
 * every path, like `validateGroupDependencies`.
 *
 * @throws InvalidOperationError for the first sort key that matches no field.
 */
const validateSorting = (sorting: common.Sorting[] | undefined, fields: Fields): void => {
  if (sorting === undefined) {
    return;
  }
  for (const sort of sorting) {
    if (!fields.some((field) => field.key === sort.key)) {
      throw new InvalidOperationError(`Field "${sort.key}" not found`);
    }
  }
};

/**
 * Ensures every grouping key references an existing field.
 *
 * @throws InvalidOperationError for the first grouping key that matches no field.
 */
const validateGroupDependencies = (grouping: common.Grouping[], fields: Fields) => {
  const availableKeys = new Set(fields.map((field) => field.key));
  for (const { key } of grouping) {
    if (!availableKeys.has(key)) {
      throw new InvalidOperationError(`Field "${key}" not found`);
    }
  }
};
/** Returns the first field whose key equals `key`, or `undefined` when there is none. */
export const findField = (fields: Fields, key: string) => {
  return fields.find((candidate) => candidate.key === key);
};

/** Returns the name of the field with the given key, falling back to the key itself. */
export const findFieldName = (fields: Fields, key: string) => {
  const match = findField(fields, key);
  return match?.name ?? key;
};

/**
 * Translates an expression with `translateToPrql` and returns its reported return type
 * and format.
 *
 * @throws InvalidOperationError whenever the translation throws. The original message is
 *   reused when the thrown value is an `Error`; otherwise a generic message naming the
 *   expression is used.
 */
const getExpressionResult = (
  expression: string,
  fieldGroups: FieldGroup[],
  variables: QueryVariables,
) => {
  try {
    const translation = translateToPrql(expression, {
      fields: convertFieldsForPenguinoContext(fieldGroups),
      variables,
    });
    return { returnType: translation.returnType, format: translation.format };
  } catch (error) {
    const message = error instanceof Error ? error.message : `Invalid expression "${expression}"`;
    throw new InvalidOperationError(message);
  }
};

/**
 * Determines the type of the field produced by an aggregation.
 *
 * Resolution order:
 * 1. a type fixed by the aggregation type itself (counts, metrics);
 * 2. keyed aggregations take the type of the aggregated field;
 * 3. expression aggregations take the expression's return type, or `'String'` when the
 *    translation reports none;
 * 4. metric aggregations resolve through the metric list;
 * 5. anything else is `'String'`.
 *
 * @throws InvalidOperationError when the referenced field or metric is missing or the
 *   expression cannot be translated.
 */
const getResultTypeForAggregation = (
  aggregation: Aggregation,
  fieldGroups: FieldGroup[],
  variables: QueryVariables,
  metrics: Metric[],
) => {
  const fixedType = getFieldTypeByAggregationType(aggregation.type);
  if (fixedType !== null) {
    return fixedType;
  }
  if (isKeyedAggregation(aggregation)) {
    return getResultTypeforKeyedAggregation(aggregation, fieldGroups);
  }
  if (isExprAggregation(aggregation)) {
    const { returnType } = getExpressionResult(
      aggregation.value.expression,
      fieldGroups,
      variables,
    );
    return returnType ?? 'String';
  }
  if (isMetricAggregation(aggregation)) {
    return getResultTypeForMetric(aggregation.metricId, metrics);
  }
  return 'String';
};

/**
 * Returns the type of the field a keyed aggregation operates on, searching all field
 * groups in order.
 *
 * @throws InvalidOperationError when no field with the aggregation's key has a type.
 */
const getResultTypeforKeyedAggregation = (aggregation: Aggregation, fieldGroups: FieldGroup[]) => {
  const allFields = fieldGroups.flatMap(({ fields }) => fields);
  const aggregatedField = allFields.find((candidate) => candidate.key === aggregation.key);
  const type = aggregatedField?.type;
  if (type === undefined) {
    throw new InvalidOperationError(`Field "${aggregation.key}" not found`);
  }
  return type;
};

/**
 * Returns the result type of a metric aggregation, which is always `'Number'`.
 *
 * @throws InvalidOperationError when `metricId` is not among `metrics`.
 */
const getResultTypeForMetric = (metricId: string, metrics: Metric[]) => {
  const isKnownMetric = metrics.some((metric) => metric.metricId === metricId);
  if (isKnownMetric) {
    return 'Number';
  }
  throw new InvalidOperationError(`Metric "${metricId}" not found`);
};

/**
 * Resolves a grouping precision to a concrete time interval.
 *
 * A literal precision is used directly. An expression precision is resolved to the value
 * of the variable whose variable expression equals it. The resolved value is then
 * validated with `common.timeInterval`.
 *
 * @returns The validated interval, or `undefined` when no precision is given.
 * @throws InvalidOperationError when the resolved value is not a valid time interval
 *   (including when no variable matches the expression).
 */
export const precisionToTimeInterval = (
  precision: common.TimeInterval | common.ExpressionValue | undefined,
  variables: common.QueryVariables,
) => {
  if (precision === undefined) {
    return undefined;
  }

  const rawValue = isValueExpression(precision)
    ? variables.find((variable) => getVariableExpression(variable.key) === precision.expression)
        ?.value
    : precision;

  const parsed = common.timeInterval.safeParse(rawValue);
  if (parsed.success) {
    return parsed.data;
  }
  throw new InvalidOperationError(`Invalid precision value: ${rawValue}`);
};
