import React, { useRef, useEffect } from "react";
import PropTypes from "prop-types";
import * as d3 from "d3";
import { dataFieldsMetaTemplate } from "../../constants";

// Fill/stroke colour used for each variant class in the gene-by-sample grid.
// Keys are the variant names found on the variant records; values are d3
// colour objects built from the CSS strings listed here.
const oncoprintColors = Object.fromEntries(
  [
    ["No_Variant", "rgb(204, 204, 204)"],
    ["Splice_Site", "rgb(240, 224, 130)"],
    ["In_Frame_Del", "rgb(115, 198, 107)"],
    ["In_Frame_Ins", "rgb(22, 160, 9)"],
    ["Missense_Mutation", "rgb(153, 187, 217)"],
    ["Nonsense_Mutation", "rgb(52, 119, 180)"],
    ["Frame_Shift_Del", "rgb(156, 138, 181)"],
    ["Frame_Shift_Ins", "rgb(90, 60, 133)"],
    ["Multiple_Variants", "rgb(0, 0, 0)"],
  ].map(([variantName, cssColor]) => [variantName, d3.rgb(cssColor)])
);

// Sort rank of every known variant class. Built once at module load instead
// of on every call (assessVariant runs inside a sort comparator), and stored
// in a Map so inherited Object.prototype keys such as "constructor" can never
// be mistaken for a variant name.
const variantRank = new Map([
  ["No_Variant", 0],
  ["Splice_Site", 1],
  ["In_Frame_Del", 2],
  ["In_Frame_Ins", 3],
  ["Missense_Mutation", 4],
  ["Nonsense_Mutation", 5],
  ["Frame_Shift_Del", 6],
  ["Frame_Shift_Ins", 7],
  ["Multiple_Variants", 8],
]);

/**
 * Look up the numeric sort rank of a variant class.
 *
 * @param {string} variant - Variant class name, e.g. "Missense_Mutation".
 * @returns {number|undefined} Rank from 0 ("No_Variant") to 8
 *   ("Multiple_Variants"), or undefined for an unrecognised name.
 */
const assessVariant = (variant) => variantRank.get(variant);

const Chart = ({
  data,
  geneList,
  sortBy,
  selectedGene,
  selectedDataField,
  options,
  saveCohort,
}) => {
  const ref = useRef(null);
  useEffect(() => {
    if (data && geneList && options && ref.current) {
      let svg = d3.select(ref.current); // eslint-disable-line

      saveCohort; // eslint-disable-line

      const { allIds, rawVariantData, dataFields: dataFieldsAllSamples } = data;

      if (allIds.length === 0) {
        return;
      }

      // Keep only samples that have DNA data
      const dataFields = dataFieldsAllSamples.filter((d) =>
        allIds.includes(d.ids)
      );

      // Get the names of the user specified data fields (ignore ids and Cohort)
      const dataFieldNames = Object.keys(dataFields[0]).filter(
        (d) => !["ids", "Cohort"].includes(d)
      );

      // Instantiate an array with meta data for data fields
      const dataFieldsMeta = dataFieldsMetaTemplate
        .filter((d) => dataFieldNames.includes(d.name))
        .map((d) => ({ ...d, color: d.color(dataFields) }));

      // Group variants by sample id + gene
      const variantData = Array.from(
        d3.group(rawVariantData, (d) => `${d.ids}_${d.user_specified}`).values()
      );

      // Get unique genes
      const uniqueGenes = Array.from(
        new Set(rawVariantData.map((d) => d.user_specified))
      );

      // Get unique sample ids and those without any mutations
      const uniqueSamples = Array.from(
        new Set(rawVariantData.map((d) => d.ids))
      );
      const samplesWithoutVariants = allIds.filter(
        (id) => !uniqueSamples.includes(id)
      );

      // Layout
      const containerDim = document
        .getElementById("mutation-chart-container")
        .getBoundingClientRect();

      const margin = {
        top: 40,
        right: 150,
        bottom: 80,
        left: 150,
      };

      const rectHSpace = Math.min(
        10 * 1.61803398875,
        (containerDim.height - margin.top - margin.bottom) /
          (dataFieldsMeta.length + uniqueGenes.length)
      );
      const rectWSpace = Math.min(
        10,
        (containerDim.width - margin.left - margin.right) / allIds.length
      );
      const rectHeight = rectHSpace * 0.85;
      const rectWidth = rectWSpace * 0.85;

      const height =
        margin.top +
        dataFieldsMeta.length * rectHSpace +
        uniqueGenes.length * rectHSpace +
        margin.bottom;
      const width = margin.left + allIds.length * rectWSpace + margin.right;

      // Sort gene names
      let sortedGeneNames;
      if (
        sortBy === "gene" &&
        uniqueGenes.includes(selectedGene[selectedGene.type])
      ) {
        sortedGeneNames = d3
          .sort(
            Array.from(
              d3.rollup(
                variantData,
                (v) => v.length,
                (d) => d[0].user_specified
              )
            ),
            (a, b) => d3.descending(a[1], b[1])
          )
          .map(([name, _unused]) => name);
        sortedGeneNames.splice(
          sortedGeneNames.indexOf(selectedGene[selectedGene.type]),
          1
        );
        sortedGeneNames.unshift(selectedGene[selectedGene.type]);
      } else {
        sortedGeneNames = d3
          .sort(
            Array.from(
              d3.rollup(
                variantData,
                (v) => v.length,
                (d) => d[0].user_specified
              )
            ),
            (a, b) => d3.descending(a[1], b[1])
          )
          .map(([name, _unused]) => name);
      }

      // Sort samples
      const variantMap = new Map();
      variantData
        .map((d) =>
          d.length === 1 ? d[0] : { ...d[0], variant: "Multiple_Variants" }
        )
        .forEach((d) => variantMap.set(`${d.ids}_${d.user_specified}`, d));

      const variantMatrix = uniqueSamples.map((id) =>
        sortedGeneNames.map((gene) => ({
          id,
          gene,
          variant: variantMap.has(`${id}_${gene}`)
            ? variantMap.get(`${id}_${gene}`).variant
            : "No_Variant",
        }))
      );

      uniqueGenes.forEach((unused, i) =>
        variantMatrix.sort((a, b) =>
          d3.ascending(
            assessVariant(b[uniqueGenes.length - i - 1].variant),
            assessVariant(a[uniqueGenes.length - i - 1].variant)
          )
        )
      );
      let sortedSamples = variantMatrix
        .map((v) => v[0].id)
        .concat(samplesWithoutVariants);

      if (sortBy === "dataField") {
        const { cmp } = dataFieldsMeta.filter(
          (d) => d.id === selectedDataField
        )[0];
        sortedSamples = dataFields
          .sort(
            (a, b) =>
              sortedSamples.indexOf(a.ids) - sortedSamples.indexOf(b.ids)
          )
          .sort(cmp)
          .map((d) => d.ids);
      }

      // Scales
      const scaleDataFields = d3
        .scaleBand()
        .domain(dataFieldsMeta.map((d) => d.name))
        .range([margin.top, margin.top + dataFieldsMeta.length * rectHSpace]);

      const scaleSample = d3
        .scaleBand()
        .domain(sortedSamples)
        .range([margin.left, margin.left + allIds.length * rectWSpace]);

      const scaleGene = d3
        .scaleBand()
        .domain(sortedGeneNames)
        .range([
          margin.top + dataFieldsMeta.length * rectHSpace,
          margin.top +
            dataFieldsMeta.length * rectHSpace +
            sortedGeneNames.length * rectHSpace,
        ]);

      const createScaleDataField = (m) => {
        if (m.type === "categorical") {
          return (_unused) => scaleDataFields(m.id);
        }
        if (m.type === "linear") {
          return d3
            .scaleLinear()
            .domain([
              d3.min(dataFields.map((r) => r[m.name])),
              d3.max(dataFields.map((r) => r[m.name])),
            ])
            .range([0, rectHeight]);
        }
        if (m.type === "log") {
          return d3
            .scaleLog()
            .domain([
              d3.min(dataFields.map((r) => r[m.name])),
              d3.max(dataFields.map((r) => r[m.name])),
            ])
            .range([0, rectHeight]);
        }
        return undefined;
      };

      const dataFieldsMetaFinal = dataFieldsMeta.map((m) => ({
        ...m,
        scale: createScaleDataField(m),
      }));

      // Axes
      const dataFieldAxis = d3
        .axisLeft(scaleDataFields)
        .tickSize(0)
        .tickPadding(4);

      const geneAxis = d3.axisLeft(scaleGene).tickSize(0).tickPadding(4);

      svg.selectAll("g.data-field").remove();

      dataFieldsMetaFinal.forEach((m) => {
        svg
          .append("g")
          .attr("class", `data-field ${m.id}`)
          .selectAll("rect")
          .data(dataFields)
          .join("rect")
          .attr("fill-opacity", 0.9)
          .attr("width", rectWidth)
          .attr("height", (d) =>
            m.type === "categorical" ? rectHeight : m.scale(d[m.name])
          )
          .attr("x", (d) => scaleSample(d.ids))
          .attr(
            "y",
            (d) =>
              scaleDataFields(m.name) +
              (m.type === "categorical" ? 0 : rectHeight - m.scale(d[m.name]))
          )
          .attr("fill", (d) => m.color(d[m.name]))
          .style("stroke-width", 0)
          .style("stroke", (d) => m.color(d[m.name]));

        svg
          .append("g")
          .attr("class", `data-field ${m.id}`)
          .selectAll("text")
          .data(dataFields.filter((d) => d[m.name] === null))
          .join("text")
          .style("font-size", 8)
          .attr("x", (d) => scaleSample(d.ids))
          .attr("y", scaleDataFields(m.name) + rectHeight / 1.5)
          .text("x");
      });

      const dataFieldRect = svg.selectAll("g.data-field").selectAll("rect");

      svg
        .append("g")
        .attr("class", "axis data-field")
        .attr("transform", `translate(${margin.left}, 0)`)
        .call(dataFieldAxis)
        .attr("font-size", 12) // this has to be after the call to axis func
        .call((g) => g.select(".domain").remove());

      svg.append("g").attr("class", "hover-panel");

      /* Save Cohort Button */

      const saveCohortButton = svg
        .select("#save-cohort-button")
        .select("rect")
        .attr("x", 150)
        .attr("y", 5)
        .attr("height", 20)
        .attr("width", 120)
        .attr("rx", 4)
        .style("fill", "lightgray")
        .style("stroke", "black")
        .style("stroke-width", "0.5")
        .style("cursor", "not-allowed");

      const saveCohortText = svg
        .select("#save-cohort-button")
        .select("text")
        .attr("x", 150 + 120 + 5)
        .attr("y", 5 + 15);

      saveCohortText.text("No samples selected.");

      /* Interaction */

      const hoverPanel = svg
        .selectAll("g.hover-panel")
        .attr("font-family", "monospace")
        .attr("font-size", 12);

      hoverPanel.append("text");

      // const panelBackground = svg
      //   .selectAll("g.hover-panel")
      //   .append("rect")
      //   .attr("x", 0)
      //   .attr("y", 0)
      //   .attr("width", 0)
      //   .attr("height", 0)
      //   .attr("fill", "none");

      let persist = false;

      const rectangles = svg
        .selectAll("g.rectangles")
        .selectAll("rect")
        .data(variantData)
        .join("rect")
        .attr("fill-opacity", 0.9)
        .attr("width", rectWidth)
        .attr("height", rectHeight)
        .attr("x", (d) => scaleSample(d[0].ids))
        .attr("y", (d) => scaleGene(d[0].user_specified))
        .attr("fill", (d) =>
          d.length > 1 ? "black" : oncoprintColors[d[0].variant]
        )
        .style("stroke-width", 0)
        .style("stroke", (d) =>
          d.length > 1 ? "black" : oncoprintColors[d[0].variant]
        );

      const transformRect = (e) =>
        `translate(${e.currentTarget.x.baseVal.value + 12},${
          e.currentTarget.y.baseVal.value + 12
        })`;

      svg.selectAll("g.canvas").remove();
      if (options.interaction === "tooltip") {
        rectangles
          .on("mouseover", (e, d) => {
            if (persist) {
              return;
            }
            hoverPanel.attr("transform", transformRect(e));
            // let totalRows = 0;
            // let maxRowLength = 0;
            d.forEach((x, i) => {
              // totalRows += 1;
              // maxRowLength = d3.max([
              //   maxRowLength,
              //   x.ensembl.length + x.variant.length,
              // ]);
              hoverPanel
                .append("text")
                .attr("x", 6)
                .attr("y", (i + 1) * 12)
                .text(`${x.ensembl} | ${x.variant} | ${x.hgvsp_short}`);
            });
            d3.select(e.currentTarget).style("stroke-width", 0.9);
            // panelBackground
            //   .attr("height", totalRows * 8)
            //   .attr("width", maxRowLength * 6)
            //   .attr("fill", "#f2f2f2");
          })
          .on("mouseout", (_unused, _unused2) => {
            if (persist) {
              return;
            }
            hoverPanel.attr("transform", `translate(0,0)`);
            // panelBackground.attr("fill", "none");
            hoverPanel.selectAll("text").remove();
            rectangles.style("stroke-width", 0);
            dataFieldRect.style("stroke-width", 0);
          })
          .on("click", (e, d) => {
            if (persist) {
              hoverPanel.attr("transform", `translate(0,0)`);
              // panelBackground.attr("fill", "none");
              hoverPanel.selectAll("text").remove();
              rectangles.style("stroke-width", 0);
              dataFieldRect.style("stroke-width", 0);
            } else {
              hoverPanel.attr("transform", transformRect(e));
              hoverPanel.selectAll("text").remove();
              // let totalRows = 0;
              // let maxRowLength = 0;
              d.forEach((x, i) => {
                // totalRows += 1;
                // maxRowLength = d3.max([
                //   maxRowLength,
                //   x.ensembl.length + x.variant.length,
                // ]);
                hoverPanel
                  .append("text")
                  .attr("x", 6)
                  .attr("y", (i + 1) * 12)
                  .text(`${x.ensembl} | ${x.variant} | ${x.hgvsp_short}`);
              });
              d3.select(e.currentTarget).style("stroke-width", 0.9);
              // panelBackground
              //   .attr("height", totalRows * 8)
              //   .attr("width", maxRowLength * 6)
              //   .attr("fill", "#f2f2f2");
            }
            persist = !persist;
          });

        dataFieldRect
          .on("mouseover", (e, d) => {
            if (persist) {
              return;
            }
            hoverPanel.attr("transform", transformRect(e));
            // let totalRows = 0;
            // let maxRowLength = 0;
            let j = 0;
            Object.entries(d).forEach(([k, v], _unused) => {
              if (["ids", "Cohort"].includes(k)) {
                return;
              }
              j += 1;
              // totalRows += 1;
              // maxRowLength = d3.max([maxRowLength, k.length + v.length]);
              hoverPanel
                .append("text")
                .attr("x", 6)
                .attr("y", (j + 1) * 12)
                .text(`${k} | ${v}`);
            });
            d3.select(e.currentTarget).style("stroke-width", 0.9);
            // panelBackground
            //   .attr("height", totalRows * 8)
            //   .attr("width", maxRowLength * 6)
            //   .attr("fill", "#f2f2f2");
          })
          .on("mouseout", (_unused, _unused2) => {
            if (persist) {
              return;
            }
            hoverPanel.attr("transform", `translate(0,0)`);
            // panelBackground.attr("fill", "none");
            hoverPanel.selectAll("text").remove();
            rectangles.style("stroke-width", 0);
            dataFieldRect.style("stroke-width", 0);
          })
          .on("click", (e, d) => {
            if (persist) {
              hoverPanel.attr("transform", `translate(0,0)`);
              // panelBackground.attr("fill", "none");
              hoverPanel.selectAll("text").remove();
              rectangles.style("stroke-width", 0);
              dataFieldRect.style("stroke-width", 0);
            } else {
              hoverPanel.attr("transform", transformRect(e));
              // let totalRows = 0;
              // let maxRowLength = 0;
              let j = 0;
              Object.entries(d).forEach(([k, v], _unused) => {
                if (["ids", "Cohort"].includes(k)) {
                  return;
                }
                j += 1;
                // totalRows += 1;
                // maxRowLength = d3.max([maxRowLength, k.length + v.length]);
                hoverPanel
                  .append("text")
                  .attr("x", 6)
                  .attr("y", (j + 1) * 12)
                  .text(`${k} | ${v}`);
              });
              d3.select(e.currentTarget).style("stroke-width", 0.9);
              // panelBackground
              //   .attr("height", totalRows * 8)
              //   .attr("width", maxRowLength * 6)
              //   .attr("fill", "#f2f2f2");
            }
            persist = !persist;
          });
      } else {
        svg.append("g").attr("class", "canvas");
        const brushedABC = ({ selection }) => {
          if (selection) {
            const [x0, x1] = selection;
            rectangles
              .attr("fill-opacity", 0.1)
              .filter(
                (d) =>
                  x0 <= scaleSample(d[0].ids) &&
                  x1 >= scaleSample(d[0].ids) + rectWidth
              )
              .attr("fill-opacity", 1);
            const brushRectData = dataFieldRect
              .attr("fill-opacity", 0.1)
              .filter(
                (d) =>
                  x0 <= scaleSample(d.ids) &&
                  x1 >= scaleSample(d.ids) + rectWidth
              )
              .attr("fill-opacity", 1)
              .data();
            const brushIds = Array.from(
              new Set(brushRectData.flat().map((d) => d.ids))
            );
            saveCohortText.text(`Selected ${brushIds.length} samples.`);
            saveCohortButton.on("click", () => {
              saveCohort(
                brushIds.map((id) => id.split("+")[0]),
                brushIds.map((id) => id.split("+")[1]),
                "Mutation",
                []
              );
            });
            saveCohortButton.style("cursor", "pointer");
          } else {
            rectangles.attr("fill-opacity", 0.9);
            dataFieldRect.attr("fill-opacity", 0.9);
            saveCohortButton.on("click", () => {});
            saveCohortButton.style("cursor", "not-allowed");
          }
        };

        const brush = d3
          .brushX()
          .extent([
            [margin.left - 2, margin.top - 2],
            [width - margin.right + 2, height - margin.bottom + 2],
          ])
          .on("start brush end", brushedABC);

        svg.selectAll("g.canvas").call(brush);
      }

      svg
        .selectAll("g.axis.gene")
        .attr("transform", `translate(${margin.left}, 0)`)
        .call(geneAxis)
        .attr("font-size", 12) // this has to be after the call to geneAxis
        .call((g) => g.select(".domain").remove());
    }
  }, [data, geneList, options]);

  return (
    <div id="mutation-chart-container" className="analysis-chart-container">
      <svg ref={ref} style={{ height: "100%", width: "100%" }}>
        <g id="save-cohort-button">
          <rect />
          <text />
        </g>
        <g className="data-field" />
        <g className="rectangles" />
        <g className="axis gene" />
      </svg>
    </div>
  );
};

// Object props are described by the fields the component actually reads, so
// a mis-shaped prop is reported in development instead of failing inside the
// drawing effect. Inner fields are left optional to stay backward-compatible.
Chart.propTypes = {
  data: PropTypes.shape({
    allIds: PropTypes.array,
    rawVariantData: PropTypes.array,
    dataFields: PropTypes.array,
  }).isRequired,
  geneList: PropTypes.array.isRequired,
  sortBy: PropTypes.string.isRequired,
  // The gene name is looked up as selectedGene[selectedGene.type]
  selectedGene: PropTypes.shape({
    type: PropTypes.string,
  }).isRequired,
  selectedDataField: PropTypes.string.isRequired,
  // interaction === "tooltip" selects tooltips; anything else selects brushing
  options: PropTypes.shape({
    interaction: PropTypes.string,
  }).isRequired,
  saveCohort: PropTypes.func.isRequired,
};

export default Chart;
