import Decimal from "decimal.js";
import { dayjs } from "lib/dayjs";
import {
  getUserFacingErrorMessage,
  giveUserFacingErrorMessage,
} from "app/lib/errors/errorHandling";
import { useNavigate } from "app/lib/useNavigate";
import React, { useEffect, useState } from "react";

import {
  Body,
  DateInput,
  Headline,
  Input,
  Select,
  Tooltip,
} from "design-system";

import { Button } from "components/Button";

import { DeprecatedCopyableID } from "components/deprecated/CopyableID";
import { DeprecatedTotals } from "components/deprecated/Invoice/components/Totals";
import { AppShell } from "components/AppShell";
import { useSnackbar } from "components/deprecated/Snackbar";
import { ErrorEmptyState } from "app/lib/errors/ErrorEmptyState";
import { useFeatureFlag } from "app/lib/launchdarkly";
import { useOptionalParam, useRequiredParam } from "app/lib/routes/params";
import { renderDate } from "lib/time";
import {
  InvoiceCorrectionItem,
  InvoiceCorrectionItemTypeEnum,
  SequenceNumberInput,
  InvoiceStatusEnum,
} from "types/generated-graphql/__types__";
import NotFoundPage from "app/pages/404";
import { clearCustomerInvoicesFromCache } from "app/pages/deprecated/Customer/tabs/Credits/lib/cache";
import { ContractCorrectionTable } from "./components/ContractCorrectionTable";
import { ProductChargeLineItemCorrection } from "./components/LineItem";
import {
  CorrectionDataLineItemFragment,
  CreateInvoiceCorrectionMutation,
  InvoiceCorrectionDataQuery,
  useContractUsageInvoiceCorrectionIdsQuery,
  useCreateInvoiceCorrectionMutation,
  useInvoiceCorrectionDataQuery,
} from "./correctInvoice.graphql";
import { CORRECTION_REASONS } from "./correction_reasons";
import { prepareVariablesForCorrectInvoice } from "./helpers";
import {
  createContractLineItemPointer,
  createLineItemPointer,
  lineItemPointerToString,
} from "./pointers";
import {
  ArrearsCorrectableLineItem,
  ContractCorrectableLineItem,
  CorrectableLineItem,
  CorrectableSubLineItem,
  CorrectionEditMode,
} from "./types";
import { Breadcrumbs } from "app/lib/breadcrumbs";
import { twMerge } from "twMerge";

/**
 * Type guard for line items the arrears correction flow can edit.
 *
 * @param li - a line item from the correction data query.
 * @returns true when `li.__typename` is `ProductChargeLineItem` or
 *   `GroupedProductChargeLineItem`.
 */
function lineItemIsArrearsCorrectable(
  li: CorrectionDataLineItemFragment,
): li is ArrearsCorrectableLineItem {
  const { __typename: typename } = li;
  return (
    typename === "GroupedProductChargeLineItem" ||
    typename === "ProductChargeLineItem"
  );
}

/**
 * Type guard for line items the contract correction flow can edit.
 *
 * @param li - a line item from the correction data query.
 * @returns true for contract charge, commit, discount, usage, AWS royalty and
 *   GCP royalty line items; false for every other `__typename`.
 */
function lineItemIsContractCorrectable(
  li: CorrectionDataLineItemFragment,
): li is ContractCorrectableLineItem {
  switch (li.__typename) {
    case "ContractChargeLineItem":
    case "ContractCommitLineItem":
    case "ContractDiscountLineItem":
    case "ContractUsageLineItem":
    case "ContractAWSRoyaltyLineItem":
    case "ContractGCPRoyaltyLineItem":
      return true;
    default:
      return false;
  }
}

/**
 * Builds the line item view model for correcting an arrears invoice.
 *
 * Starting from the invoice's product-charge sub line items, every non-void
 * prior correction is applied by adding its quantity and total deltas to the
 * matching usage (`ChargeLineItem`) sub line item, so the UI shows the
 * currently effective values. The correction preview's sub line items (if a
 * preview exists) are indexed by stringified line item pointer.
 *
 * @param invoice - invoice from the correction data query; any type other
 *   than `ArrearsInvoice` yields empty results.
 * @param correctionPreview - result of the latest preview mutation, if any.
 * @returns `arrearsUpdatedCorrectableLineItems` (original correctable line
 *   items with updated sub line item quantity/total) and
 *   `arrearsCorrectedLineItems` (preview sub line items keyed by pointer).
 */
function getUpdatedLineItemsForArrearsInvoice(
  invoice: InvoiceCorrectionDataQuery["mri_invoice"],
  correctionPreview:
    | CreateInvoiceCorrectionMutation["create_invoice_correction"]
    | undefined,
) {
  if (!invoice || invoice.__typename !== "ArrearsInvoice") {
    return {
      arrearsUpdatedCorrectableLineItems: [],
      arrearsCorrectedLineItems: {},
    };
  }

  const correctableLineItems =
    invoice.line_items.filter<ArrearsCorrectableLineItem>(
      lineItemIsArrearsCorrectable,
    );

  // Pointer -> current effective sub line item, seeded with the originals.
  const updatedItems = Object.fromEntries<CorrectableSubLineItem>(
    correctableLineItems.flatMap((li) =>
      li.sub_line_items.map((sli) => [
        lineItemPointerToString(createLineItemPointer(li, sli)),
        sli,
      ]),
    ),
  );

  invoice.corrections.forEach((c) => {
    // Skip voided corrections when computing new line items
    if (c.status === InvoiceStatusEnum.Void) {
      return;
    }
    c.line_items.forEach((li) => {
      if (lineItemIsArrearsCorrectable(li)) {
        li.sub_line_items.forEach((sli) => {
          // Only update quantity for usage charges
          if (sli.__typename !== "ChargeLineItem") {
            return;
          }
          const ptr = lineItemPointerToString(createLineItemPointer(li, sli));
          const current = updatedItems[ptr];
          // A prior correction may reference a sub line item that does not
          // exist on the original invoice. There is nothing on the original
          // to adjust, and dereferencing it would throw during render.
          if (!current) {
            return;
          }
          updatedItems[ptr] = {
            ...current,
            quantity: new Decimal(current.quantity)
              .add(sli.quantity)
              .toString(),
            total: new Decimal(current.total).add(sli.total).toString(),
          };
        });
      }
    });
  });

  const arrearsUpdatedCorrectableLineItems: CorrectableLineItem[] =
    correctableLineItems.map<CorrectableLineItem>((li) => {
      const subLineItems = li.sub_line_items.map((sli) => {
        const current =
          updatedItems[lineItemPointerToString(createLineItemPointer(li, sli))];
        // Every original sub line item was seeded above, so `current` is
        // expected to exist; fall back to the original rather than emitting
        // an `undefined` entry.
        return current
          ? ({
              ...sli,
              quantity: current.quantity,
              total: current.total,
            } as CorrectableSubLineItem)
          : sli;
      });
      return { ...li, sub_line_items: subLineItems } as CorrectableLineItem;
    });

  const arrearsCorrectedLineItems = Object.fromEntries<CorrectableSubLineItem>(
    (
      correctionPreview?.line_items.filter<ArrearsCorrectableLineItem>(
        lineItemIsArrearsCorrectable,
      ) ?? []
    ).flatMap((li) =>
      "sub_line_items" in li
        ? li.sub_line_items.map((sli) => [
            lineItemPointerToString(createLineItemPointer(li, sli)),
            sli,
          ])
        : [],
    ),
  );

  return {
    arrearsUpdatedCorrectableLineItems,
    arrearsCorrectedLineItems,
  };
}

/**
 * Collapses contract line items that share the same `__typename` and line
 * item pointer into a single entry. The merged entry's `quantity` and `total`
 * are the Decimal sums across the group.
 *
 * Entries are shallow copies of the first item in each group (with
 * quantity/total overwritten), so callers may mutate quantity/total without
 * touching the input objects.
 *
 * @param lineItems - contract-correctable line items to merge.
 * @returns map keyed by `__typename + lineItemPointerToString(pointer)`.
 */
function flattenContractLineItems(lineItems: ContractCorrectableLineItem[]) {
  const merged: {
    [pointer: string]: ContractCorrectableLineItem;
  } = {};

  for (const item of lineItems) {
    switch (item.__typename) {
      case "ContractCommitLineItem":
      case "ContractDiscountLineItem":
      case "ContractChargeLineItem":
      case "ContractUsageLineItem":
      case "ContractAWSRoyaltyLineItem":
      case "ContractGCPRoyaltyLineItem": {
        const groupKey =
          item.__typename +
          lineItemPointerToString(createContractLineItemPointer(item));

        // Start each group from a zeroed copy so the sums below include the
        // first item exactly once.
        const group = merged[groupKey] ?? {
          ...item,
          quantity: "0",
          total: "0",
        };
        group.quantity = new Decimal(group.quantity)
          .add(new Decimal(item.quantity))
          .toString();
        group.total = new Decimal(group.total)
          .add(new Decimal(item.total))
          .toString();
        merged[groupKey] = group;
        break;
      }
      default:
        break;
    }
  }

  return merged;
}

/**
 * Pairs each contract line item (with non-void prior corrections already
 * applied) with the matching line item from the correction preview.
 *
 * Line items are matched by the `flattenContractLineItems` key (typename +
 * pointer). Rows for existing line items come first. Preview line items with
 * no matching existing line item are appended with `originalLineItem`
 * undefined.
 *
 * @param invoice - invoice from the correction data query; anything other
 *   than a `ContractScheduledInvoice` or `ContractUsageInvoice` yields `[]`.
 * @param correctionPreview - result of the latest preview mutation, if any.
 */
function getUpdatedLineItemsForContractInvoice(
  invoice: InvoiceCorrectionDataQuery["mri_invoice"],
  correctionPreview:
    | CreateInvoiceCorrectionMutation["create_invoice_correction"]
    | undefined,
): {
  originalLineItem?: ContractCorrectableLineItem;
  correctedLineItem?: ContractCorrectableLineItem;
}[] {
  if (
    !(
      invoice &&
      (invoice.__typename === "ContractScheduledInvoice" ||
        invoice.__typename === "ContractUsageInvoice")
    )
  ) {
    return [];
  }

  // Adds `delta`'s quantity and total onto `target` in place.
  const accumulate = (
    target: ContractCorrectableLineItem,
    delta: ContractCorrectableLineItem,
  ) => {
    target.quantity = new Decimal(target.quantity)
      .add(new Decimal(delta.quantity))
      .toString();
    target.total = new Decimal(target.total)
      .add(new Decimal(delta.total))
      .toString();
  };

  const currentByKey = flattenContractLineItems(
    invoice.line_items.filter<ContractCorrectableLineItem>(
      lineItemIsContractCorrectable,
    ),
  );
  const previewByKey = flattenContractLineItems(
    (correctionPreview?.line_items ?? []).filter<ContractCorrectableLineItem>(
      lineItemIsContractCorrectable,
    ),
  );

  for (const correction of invoice.corrections) {
    // Voided corrections no longer affect the effective line items.
    if (correction.status === InvoiceStatusEnum.Void) {
      continue;
    }
    const correctionByKey = flattenContractLineItems(
      correction.line_items.filter<ContractCorrectableLineItem>(
        lineItemIsContractCorrectable,
      ),
    );
    for (const [key, correctionItem] of Object.entries(correctionByKey)) {
      const existing = currentByKey[key];
      if (existing) {
        accumulate(existing, correctionItem);
      } else {
        currentByKey[key] = correctionItem;
      }
    }
  }

  const pairedRows = Object.entries(currentByKey).map(
    ([key, originalLineItem]) => ({
      originalLineItem,
      correctedLineItem: previewByKey[key],
    }),
  );

  const previewOnlyRows = Object.entries(previewByKey)
    .filter(([key]) => currentByKey[key] === undefined)
    .map(([, correctedLineItem]) => ({
      originalLineItem: undefined,
      correctedLineItem,
    }));

  return [...pairedRows, ...previewOnlyRows];
}

/**
 * Page for issuing a correction against an existing arrears or contract
 * invoice.
 *
 * Step 1 collects a reason and memo. Step 2 collects per-line-item
 * corrections and previews them with the `create_invoice_correction`
 * mutation (`preview: true`). Step 3 optionally sets a custom issue date.
 * "Save" runs the mutation without preview and navigates to the new
 * correction invoice.
 *
 * Route params: `invoiceId` and `customerId` are required. `customerPlanId`
 * or `contractId` (optional) select which invoice route to navigate to after
 * saving. The page is gated behind the `invoice-corrections` feature flag.
 */
export const CorrectInvoice: React.FC = () => {
  const invoiceId = useRequiredParam("invoiceId");
  const customerId = useRequiredParam("customerId");
  const customerPlanId = useOptionalParam("customerPlanId");
  const contractId = useOptionalParam("contractId");
  const navigate = useNavigate();
  const pushMessage = useSnackbar();

  const invoiceCorrectionsAllowed = useFeatureFlag(
    "invoice-corrections",
    false,
  );

  const [issueDate, setIssueDate] = useState<Date | undefined>();
  // Snapshot of the invoice from the first successful fetch. Once it is set,
  // the query below is skipped, so later cache evictions (see the mutation's
  // `update`) do not refetch and replace the invoice while it is being edited.
  const [invoice, setInvoice] = useState<
    InvoiceCorrectionDataQuery["mri_invoice"] | null
  >(null);

  const { data, loading, error } = useInvoiceCorrectionDataQuery({
    variables: {
      invoice_id: invoiceId,
    },
    // `skip` takes a boolean; coerce the snapshot instead of passing the object.
    skip: !invoiceCorrectionsAllowed || !!invoice,
  });

  // For contract usage invoices, fetch the usage invoices of the same contract
  // and period so a sequence number can be sent for each of them.
  const correctionIdsQueryVars =
    invoice?.__typename === "ContractUsageInvoice"
      ? {
          contracts: {
            include: [invoice.contract.id],
          },
          inclusiveStartDate: invoice.inclusive_start_date,
          exclusiveEndDate: invoice.exclusive_end_date,
        }
      : {
          contracts: null,
          inclusiveStartDate: null,
          exclusiveEndDate: null,
        };

  const { data: priorCorrectionsData, error: priorCorrectionsError } =
    useContractUsageInvoiceCorrectionIdsQuery({
      variables: {
        customerByPkId: customerId,
        ...correctionIdsQueryVars,
      },
      skip:
        !invoiceCorrectionsAllowed ||
        invoice?.__typename !== "ContractUsageInvoice",
    });

  const [corrections, setCorrections] = useState<
    Record<string, InvoiceCorrectionItem & { input?: number }>
  >({});

  const [reason, setReason] = useState("");
  const [memo, setMemo] = useState("");

  const [correctMutation, correctResult] = useCreateInvoiceCorrectionMutation({
    // All post-save cache invalidation lives here. A per-call `update` passed
    // to `correctMutation(...)` would replace this callback, not run alongside
    // it.
    update(cache, _, { variables }) {
      // Previews persist nothing, so there is nothing to invalidate.
      if (variables?.preview) {
        return;
      }

      clearCustomerInvoicesFromCache(cache, customerId);
      cache.evict({ fieldName: "mri_invoice" });
      cache.evict({ fieldName: "mri_invoices" });
      cache.evict({ fieldName: "Customer" });
      cache.evict({ fieldName: "Customer_by_pk" });
      cache.gc();
    },
  });

  const [editMode, setEditMode] =
    React.useState<CorrectionEditMode>("deltaTotal");

  // Single reporting point for mutation failures, for both preview and save.
  // The click handlers swallow mutation rejections so errors aren't shown
  // twice. `pushMessage` is intentionally not a dependency: only a new error
  // should produce a message.
  useEffect(() => {
    if (correctResult.error) {
      pushMessage({
        type: "error",
        content: `Failed to correct invoice: ${correctResult.error.message}`,
      });
    }
  }, [correctResult.error]);

  if (loading) {
    return null;
  }

  if (error) {
    return <ErrorEmptyState title="Could not fetch invoice" error={error} />;
  }
  if (priorCorrectionsError) {
    return (
      <ErrorEmptyState
        title="Could not fetch prior invoice corrections"
        error={priorCorrectionsError}
      />
    );
  }

  if (!invoice && data?.mri_invoice) {
    // Setting state during render: React discards this render and re-renders
    // immediately with the snapshot, so don't fall through to the 404 below.
    setInvoice(data.mri_invoice);
    return null;
  }

  if (!invoice || !invoiceCorrectionsAllowed) {
    return <NotFoundPage />;
  }

  if (
    invoice.__typename !== "ArrearsInvoice" &&
    invoice.__typename !== "ContractScheduledInvoice" &&
    invoice.__typename !== "ContractUsageInvoice"
  ) {
    return (
      <ErrorEmptyState
        title="Only arrears and contract invoices can be corrected at this time"
        error={
          new Error(
            "Only arrears and contract invoices can be corrected at this time",
          )
        }
      />
    );
  }

  let sequenceNumbers: SequenceNumberInput[] = [
    {
      original_invoice_id: invoice.id,
      sequence_number: invoice.corrections.length,
    },
  ];

  const priorMriInvoices = priorCorrectionsData?.Customer_by_pk?.mri_invoices;
  if (priorMriInvoices && invoice.__typename === "ContractUsageInvoice") {
    sequenceNumbers = priorMriInvoices.invoices.flatMap((i) =>
      i.__typename === "ContractUsageInvoice"
        ? [
            {
              original_invoice_id: i.id,
              sequence_number: i.corrections.length,
            },
          ]
        : [],
    );
  }

  const correctionPreview = correctResult.data?.create_invoice_correction;

  const { arrearsUpdatedCorrectableLineItems, arrearsCorrectedLineItems } =
    getUpdatedLineItemsForArrearsInvoice(invoice, correctionPreview);

  const contractLineItemsAndCorrections = getUpdatedLineItemsForContractInvoice(
    invoice,
    correctionPreview,
  );

  if (
    !arrearsUpdatedCorrectableLineItems.length &&
    !contractLineItemsAndCorrections.length
  ) {
    return (
      <ErrorEmptyState
        title="Invoice has no correctable line items"
        error={new Error("Invoice has no correctable line items")}
      />
    );
  }

  const isResellerUsageInvoice =
    invoice.__typename === "ContractUsageInvoice" &&
    invoice.line_items.some(
      (li) =>
        li.__typename === "ContractAWSRoyaltyLineItem" ||
        li.__typename === "ContractGCPRoyaltyLineItem",
    );

  const canSubmit =
    !correctResult.loading &&
    correctResult.called &&
    reason !== "" &&
    memo !== "";
  const submitButtonTooltipContent = !reason
    ? "Please select a reason"
    : !memo
      ? "Please enter a memo"
      : !correctResult.called
        ? "Please issue full refund or calculate new invoice first"
        : "";

  // Runs the correction mutation in preview mode. Errors from building the
  // variables are reported here. Mutation failures are reported by the
  // `correctResult.error` effect, so the rejection is only caught to avoid
  // an unhandled promise.
  const requestPreview = (items: typeof corrections) => {
    try {
      const variables = prepareVariablesForCorrectInvoice(
        invoice.id,
        reason,
        memo,
        sequenceNumbers,
        items,
        true,
      );
      void correctMutation({ variables }).catch(() => undefined);
    } catch (e) {
      pushMessage({
        type: "error",
        content: `Failed to correct invoice: ${getUserFacingErrorMessage(e)}`,
      });
    }
  };

  return (
    <AppShell
      title="Correct Invoice"
      headerProps={{
        // todo (naman): probably want the full customer / id / invoice / id / correct breadcrumbs
        breadcrumbs: Breadcrumbs.from(
          {
            label: "Customers",
            routePath: "/customers",
          },
          {
            label: invoice.customer.name,
            routePath: `/customers/${customerId}`,
          },
          { label: "Invoices", routePath: `/customers/${customerId}/invoices` },
          {
            label: invoice.id,
            routePath: `/customers/${customerId}/invoices/${invoice.id}`,
          },
        ),
      }}
    >
      <div className={twMerge("-mr-12 flex grow flex-col overflow-auto pr-12")}>
        <div className={twMerge("border-gray-100 border-b pb-8")}>
          <Headline level={6}>
            Step 1: Reason for correcting the invoice
          </Headline>
        </div>
        <div>
          <Body level={1} className="text-gray-700 my-8">
            Enter both a business reason and a memo.
          </Body>
          <div className="my-8 flex items-center space-x-8">
            <div className="w-2/5">
              <Select
                name=""
                placeholder="Select reason"
                options={CORRECTION_REASONS.map((r) => ({
                  label: r,
                  value: r,
                }))}
                multiSelect={false}
                value={reason}
                onChange={setReason}
              />
            </div>
            <Input
              placeholder="Enter memo"
              value={memo}
              onChange={setMemo}
              className="w-2/5"
            />
          </div>
        </div>
        <div className="border-gray-100 mt-24 border-b pb-8">
          <Headline level={6}>Step 2: Enter invoice corrections</Headline>
        </div>
        <div>
          <Body level={1} className="text-gray-700 my-8">
            Enter the correct amount for each charge.
          </Body>
          {isResellerUsageInvoice && (
            <>
              <Body level={1} className="my-8 text-warning-600">
                Note that correcting a usage line item may also correct
                associated composite charges on other invoices, if any.
              </Body>
            </>
          )}
          <div className="flex justify-between">
            <div>
              <Headline level={5}>
                {`Invoice correction for ${invoice.customer.name} ${renderDate(
                  invoice.__typename === "ContractScheduledInvoice"
                    ? new Date(invoice.issued_at)
                    : dayjs
                        .utc(invoice.exclusive_end_date)
                        .subtract(1, "s")
                        .toDate(),
                  {
                    isUtc: true,
                    excludeUtcLabel: true,
                  },
                )}`}
              </Headline>
              <div className="text-gray-700 flex">
                <Body level={2} className="mr-4">
                  Invoice ID:
                </Body>
                <DeprecatedCopyableID id={invoiceId} label="invoice ID" />
              </div>
            </div>
            <div className="mb-12 self-end">
              <Button
                className="mb-12 mr-12 self-end"
                onClick={() => {
                  // Zero out every correctable contract line item, except
                  // composite-product usage items.
                  const contractRefundCorrections =
                    contractLineItemsAndCorrections.flatMap(
                      ({ originalLineItem }) => {
                        if (!originalLineItem) {
                          return [];
                        }

                        const isCompositeProduct =
                          originalLineItem.__typename ===
                            "ContractUsageLineItem" &&
                          originalLineItem?.product?.__typename ===
                            "CompositeProductListItem";
                        if (isCompositeProduct) {
                          return [];
                        }

                        return [
                          {
                            input: 0,
                            type: InvoiceCorrectionItemTypeEnum.TotalChange,
                            contract_line_item_pointer:
                              createContractLineItemPointer(originalLineItem),
                            value: "0",
                          },
                        ];
                      },
                    );

                  // Zero out the quantity of every arrears usage charge.
                  const arrearsRefundCorrections =
                    arrearsUpdatedCorrectableLineItems
                      .filter<ArrearsCorrectableLineItem>(
                        lineItemIsArrearsCorrectable,
                      )
                      .flatMap((lineItem) =>
                        lineItem.sub_line_items.flatMap((subLineItem) => {
                          if (subLineItem.__typename !== "ChargeLineItem") {
                            return [];
                          }

                          return [
                            {
                              input: 0,
                              type: InvoiceCorrectionItemTypeEnum.QuantityChange,
                              line_item_pointer: createLineItemPointer(
                                lineItem,
                                subLineItem,
                              ),
                              value: "0",
                            },
                          ];
                        }),
                      );

                  const refundCorrections = [
                    ...contractRefundCorrections,
                    ...arrearsRefundCorrections,
                  ];

                  setEditMode("newTotal");
                  const correctionItems = Object.fromEntries(
                    refundCorrections.map((c) => [
                      lineItemPointerToString(
                        "line_item_pointer" in c
                          ? c.line_item_pointer
                          : c.contract_line_item_pointer,
                      ),
                      c,
                    ]),
                  );
                  setCorrections(correctionItems);
                  requestPreview(correctionItems);
                }}
                disabled={correctResult.loading}
                text="Issue full refund"
                theme="primary"
                leadingIcon="minus"
              />
              <Button
                onClick={() => requestPreview(corrections)}
                disabled={correctResult.loading}
                loading={correctResult.loading}
                className="mb-12 self-end"
                text="Calculate new invoice"
                theme="primary"
                leadingIcon="calculator"
              />
            </div>
          </div>
          {arrearsUpdatedCorrectableLineItems
            .filter<ArrearsCorrectableLineItem>(lineItemIsArrearsCorrectable)
            .map((li, i) => (
              <ProductChargeLineItemCorrection
                lineItem={li}
                key={i}
                corrections={corrections}
                correctItem={(item) => {
                  if (!item.line_item_pointer) {
                    throw new Error("Missing line item pointer");
                  }
                  setCorrections({
                    ...corrections,
                    [lineItemPointerToString(item.line_item_pointer)]: item,
                  });
                }}
                removeCorrection={(pointer) => {
                  const newCorrections = { ...corrections };
                  delete newCorrections[lineItemPointerToString(pointer)];
                  setCorrections(newCorrections);
                }}
                clearCorrections={() => setCorrections({})}
                correctedLineItems={arrearsCorrectedLineItems}
                editMode={editMode}
                onEditModeChanged={setEditMode}
              />
            ))}
          {contractLineItemsAndCorrections.length > 0 && (
            <ContractCorrectionTable
              contractLineItemsAndCorrections={contractLineItemsAndCorrections}
              corrections={corrections}
              onCorrectionsChanged={(updated) => setCorrections(updated)}
              editMode={editMode}
              onEditModeChanged={setEditMode}
            />
          )}
          {correctionPreview ? (
            <DeprecatedTotals invoice={correctionPreview} />
          ) : null}
        </div>
        <div className="border-gray-100 mt-24 border-b pb-8">
          <Headline level={6}>
            (Optional) Step 3: Enter custom issue date
          </Headline>
        </div>
        <div>
          <Body level={1} className="text-gray-700 my-8">
            The issue date for the correction invoice must be greater than or
            equal to the original invoice issue date, and must be less than or
            equal to the current date.
          </Body>
          <Body level={1} className="text-gray-700 my-8">
            If you do not enter a custom issue date, the current date will be
            used.
          </Body>
          <div className="flex gap-4">
            <DateInput
              value={issueDate}
              onChange={setIssueDate}
              minDate={dayjs.utc(invoice.issued_at).toDate()}
              maxDate={new Date()}
              isUTC
            />
          </div>
        </div>
      </div>
      <div className="shadow-inner -mx-12 flex grow-0 flex-row items-center justify-end gap-8 bg-white px-24 py-12">
        <Button onClick={() => navigate(-1)} text="Cancel" theme="linkGray" />
        <Tooltip disabled={canSubmit} content={submitButtonTooltipContent}>
          <Button
            onClick={async () => {
              let variables;
              try {
                variables = prepareVariablesForCorrectInvoice(
                  invoice.id,
                  reason,
                  memo,
                  sequenceNumbers,
                  corrections,
                  false,
                  issueDate,
                );
                // Validate against the latest preview *before* persisting.
                // Checking the saved result afterwards cannot undo a
                // correction that was already created.
                if (Number(correctionPreview?.total) > 0) {
                  throw giveUserFacingErrorMessage(
                    new Error(),
                    "New total cannot be positive",
                  );
                }
              } catch (e) {
                const message = getUserFacingErrorMessage(e);
                pushMessage({
                  type: "error",
                  content: `Failed to correct invoice: ${message}`,
                });
                return;
              }

              // Mutation failures are reported by the `correctResult.error`
              // effect; swallow the rejection here to avoid a duplicate message.
              const result = await correctMutation({ variables }).catch(
                () => undefined,
              );
              if (!result) {
                return;
              }

              navigate(
                `/customers/${invoice.customer.id}${
                  customerPlanId
                    ? `/plans/${customerPlanId}/`
                    : contractId
                      ? `/contracts/${contractId}/`
                      : "/"
                }invoices/${result.data?.create_invoice_correction.id}`,
              );
            }}
            disabled={!canSubmit}
            loading={correctResult.loading}
            text="Save"
            theme="primary"
          />
        </Tooltip>
      </div>
    </AppShell>
  );
};
