import { useGLTF } from "@react-three/drei";
import { useFrame } from "@react-three/fiber";
import { random, sampleSize } from "lodash";
import { useEffect, useMemo, useRef, useState } from "react";
import * as THREE from "three";
import { useEvents } from "../../context/Events";
import { useTimeline } from "../../context/Timeline";
import {
  ANCHOR_Y_DUR,
  PARTICLE_FORMATION_DUR,
  PARTICLE_Y_OFFSET,
  SCALE_DUR,
} from "./data/constants";
import {
  defaultCounts,
  GLTFResult,
  NodeNames,
  treeModels,
} from "./data/treeModels";
import { particleMaterial } from "./helpers/shaders";
import {
  easeInCubic,
  easeInOutCubic,
  interpolateY,
  sigmoid,
} from "./helpers/utils";

// ---------------------------------------------------------------------------
// Particle tints. Each particle is given one of these four colours at random.
// ---------------------------------------------------------------------------
const COLOR_1 = "#FFFFFF";
const COLOR_2 = "#bcb8d1";
const COLOR_3 = "#bcb8d1";
const COLOR_4 = "#8c84ad";

// Alpha for each tint, in the same order as COLOR_1..COLOR_4.
const OPACITY_COLOR_1 = 0.7;
const OPACITY_COLOR_2 = 0.8;
const OPACITY_COLOR_3 = 0.8;
const OPACITY_COLOR_4 = 1;

// Share of particles drawn with tints 1-3; tint 4 receives whatever share is
// left over (the shares are accumulated and clamped to 1 when picking).
const FRACTION_COLOR_1 = 0.55;
const FRACTION_COLOR_2 = 0.2;
const FRACTION_COLOR_3 = 0.2;

/** Tint lookup table, indexed by the colour index chosen for a particle. */
const palette = [COLOR_1, COLOR_2, COLOR_3, COLOR_4];

/** Opacity lookup table, parallel to `palette`. */
const opacityValues = [
  OPACITY_COLOR_1,
  OPACITY_COLOR_2,
  OPACITY_COLOR_3,
  OPACITY_COLOR_4,
];

// ---------------------------------------------------------------------------
// Particle size tiers, all expressed relative to SCALE_UP.
// ---------------------------------------------------------------------------
const SCALE_UP = 17;
const SCALE_XS = SCALE_UP * 0.3;
const SCALE_S = SCALE_UP * 0.5;
const SCALE_M = SCALE_UP * 0.8;
const SCALE_L = SCALE_UP * 1.2;

// Share of particles in the XS, S and M tiers; the L tier takes the rest.
const FRACTION_XS = 0.5;
const FRACTION_S = 0.3;
const FRACTION_M = 0.2;

/** Brightness uniform value used whenever the particles are dimmed. */
const LOW_BRIGHTNESS = 0.1;

/** One sampled model vertex as an `[x, y, z]` tuple. */
type Vertex = [number, number, number];

/**
 * Cumulative upper bounds (each clamped to 1) that map a uniform random value
 * in [0, 1) onto a palette index: the first bound the value is below gives the
 * index; a value at or above every bound maps to the last palette entry.
 */
const COLOR_THRESHOLDS = [
  FRACTION_COLOR_1,
  Math.min(FRACTION_COLOR_1 + FRACTION_COLOR_2, 1),
  Math.min(FRACTION_COLOR_1 + FRACTION_COLOR_2 + FRACTION_COLOR_3, 1),
];

/** `palette` parsed once into THREE.Color instances (same order). */
const PALETTE_COLORS = palette.map((hex) => new THREE.Color(hex));

/** Size tiers: a uniform random value r selects the first tier with r <= upTo. */
const SCALE_TIERS = [
  { upTo: FRACTION_XS, scale: SCALE_XS },
  { upTo: FRACTION_XS + FRACTION_S, scale: SCALE_S },
  { upTo: FRACTION_XS + FRACTION_S + FRACTION_M, scale: SCALE_M },
  { upTo: 1, scale: SCALE_L },
];

/** Seconds over which a particle's wobble amplitude settles after activation. */
const MOTION_SETTLE_SECONDS = 3.0;
/** Wobble amplitude at the moment a particle is activated. */
const MOTION_INITIAL = 0.3;
/** Resting wobble amplitude. */
const MOTION_REST = 0.02;
/** Fraction of its final size a particle starts its grow-in animation at. */
const SCALE_START_FACTOR = 0.1;
/** Start-time sentinel for particles the reveal front has not reached yet. */
const NOT_STARTED = -1;

// Scratch matrix reused for every instance on every frame. setMatrixAt copies
// the matrix into the instance buffer and frame callbacks run synchronously,
// so a single shared object avoids a per-particle allocation each frame.
const scratchMatrix = new THREE.Matrix4();

/** Maps a uniform random value in [0, 1) to an index into `palette`. */
const pickColorIndex = (randomValue: number): number => {
  const hit = COLOR_THRESHOLDS.findIndex((bound) => randomValue < bound);
  return hit === -1 ? COLOR_THRESHOLDS.length : hit;
};

/** Maps a uniform random value in [0, 1) to a tier scale times `multiplier`. */
const pickScale = (randomValue: number, multiplier: number): number => {
  const tier = SCALE_TIERS.find(({ upTo }) => randomValue <= upTo);
  return tier ? tier.scale * multiplier : SCALE_XS;
};

/**
 * Grow-in curve: eases from SCALE_START_FACTOR * baseScale up to baseScale over
 * SCALE_DUR seconds, clamping once the duration has elapsed.
 */
const growIn = (
  baseScale: number,
  secondsSinceStart: number,
  ease: (t: number) => number,
): number =>
  THREE.MathUtils.lerp(
    baseScale * SCALE_START_FACTOR,
    baseScale,
    ease(Math.min(secondsSinceStart / SCALE_DUR, 1.0)),
  );

/**
 * Renders the vertices of one GLTF node as an instanced cloud of tiny spheres.
 *
 * Up to `defaultCounts[nodesName]` vertices are sampled at random. Each one
 * gets a random tint/opacity, size tier, wobble direction and wobble speed.
 * Every frame, a reveal front `runnerY` (from `interpolateY`) is advanced.
 * While `elapsed < totalTime`, a particle whose y lies below the front is
 * latched as started, and its wobble amplitude then eases from 0.3 to 0.02
 * over 3 s. `runnerY` is also forwarded to the shader as a uniform; how the
 * shader uses it is not visible here.
 *
 * @param modelPath  GLTF asset path passed to `useGLTF`.
 * @param nodesName  GLTF node whose geometry is sampled; also keys `defaultCounts`.
 * @param detailOnly When true, particles are half size and the brightness
 *                   uniform is pinned to LOW_BRIGHTNESS.
 */
const ModelAsParticles = ({
  modelPath,
  nodesName,
  detailOnly,
}: {
  modelPath: string;
  nodesName: NodeNames;
  detailOnly?: boolean;
}) => {
  const { step } = useEvents();
  const { skip } = useTimeline();

  // Both durations collapse to 0 once the timeline is skipped (and stay 0).
  const [revealTime, setRevealTime] = useState<number>(ANCHOR_Y_DUR);
  const [totalTime, setTotalTime] = useState<number>(PARTICLE_FORMATION_DUR);

  // Pure function of `step`; no need to mirror it into state via an effect.
  const brightness = step === -1 ? 1 : LOW_BRIGHTNESS;

  const particleCount = defaultCounts[nodesName];
  const { nodes } = useGLTF(modelPath) as GLTFResult;

  // Random subset of the node's vertices; one particle per sampled vertex.
  // Keyed on everything it reads so a different model/node is resampled.
  const sampledVertices = useMemo(() => {
    const positions = nodes[nodesName].geometry.attributes.position
      .array as Float32Array;
    const vertices: Vertex[] = [];
    for (let i = 0; i < positions.length; i += 3) {
      vertices.push([positions[i], positions[i + 1], positions[i + 2]]);
    }
    return sampleSize(vertices, particleCount);
  }, [nodes, nodesName, particleCount]);

  // sampleSize returns fewer items than requested when the mesh has fewer
  // vertices, so size all per-instance data from the actual sample.
  const count = sampledVertices.length;

  // Per-instance shader attributes: RGB tint, opacity, and model-space y.
  const attributes = useMemo(() => {
    const colors = new Float32Array(count * 3);
    const opacities = new Float32Array(count);
    const y = new Float32Array(count);

    sampledVertices.forEach((vertex, i) => {
      const colorIndex = pickColorIndex(Math.random());
      const color = PALETTE_COLORS[colorIndex];
      colors[i * 3] = color.r;
      colors[i * 3 + 1] = color.g;
      colors[i * 3 + 2] = color.b;
      opacities[i] = opacityValues[colorIndex];
      y[i] = vertex[1];
    });
    return { colors, opacities, y };
  }, [sampledVertices, count]);

  useEffect(() => {
    if (skip) {
      setRevealTime(0);
      setTotalTime(0);
    }
  }, [skip]);

  // Per-particle base size; detailOnly particles are drawn at half size.
  const scales = useMemo(() => {
    const multiplier = detailOnly ? 0.5 : 1;
    return sampledVertices.map(() => pickScale(Math.random(), multiplier));
  }, [sampledVertices, detailOnly]);

  // Random unit wobble directions. `floating = true` is required: with
  // integer bounds lodash's random(-1, 1) only ever returns -1, 0 or 1.
  const randomDirections = useMemo(
    () =>
      sampledVertices.map(() =>
        new THREE.Vector3(
          random(-1, 1, true),
          random(-1, 1, true),
          random(-1, 1, true),
        ).normalize(),
      ),
    [sampledVertices],
  );

  // Per-particle wobble frequency multiplier (0.5 has a float bound, so this
  // already yields floats).
  const randomSpeeds = useMemo(
    () => sampledVertices.map(() => random(0.5, 1)),
    [sampledVertices],
  );

  // Per-instance copy of the shared shader material. The frame loop writes
  // instance-specific uniforms (brightness differs for detailOnly), which
  // would otherwise overwrite one another on the single imported material.
  const material = useMemo(() => particleMaterial.clone(), []);
  useEffect(() => () => material.dispose(), [material]);

  const instancedMeshRef = useRef<THREE.InstancedMesh>(null!);
  const runnerYRef = useRef<number>(0);

  // Time (s) at which the reveal front first passed each particle, or
  // NOT_STARTED. Lazily (re)allocated only when the instance count changes.
  const startTimesRef = useRef<Float32Array | null>(null);
  if (startTimesRef.current?.length !== count) {
    startTimesRef.current = new Float32Array(count).fill(NOT_STARTED);
  }

  useFrame(({ clock }) => {
    const elapsed = clock.getElapsedTime();

    // revealTime is 0 after a skip; 0/0 at elapsed 0 would be NaN, so a
    // zero-length reveal is treated as fully revealed.
    const yProgress =
      revealTime > 0
        ? Math.floor(sigmoid(elapsed / revealTime) * 1e3) / 1e3
        : 1;
    runnerYRef.current = interpolateY(yProgress);

    const mesh = instancedMeshRef.current;
    const startTimes = startTimesRef.current;
    if (!mesh || !startTimes) return;

    const runnerY = runnerYRef.current;
    const formationActive = elapsed < totalTime;
    const introActive = elapsed < SCALE_DUR;

    for (let i = 0; i < count; i++) {
      const [px, py, pz] = sampledVertices[i];
      const baseScale = scales[i];
      let motionScalar = MOTION_REST;
      let scale = baseScale;

      if (formationActive) {
        // Latch the moment the reveal front first passes this particle.
        if (startTimes[i] === NOT_STARTED && runnerY > attributes.y[i]) {
          startTimes[i] = elapsed;
        }
        if (startTimes[i] !== NOT_STARTED) {
          const settle = Math.min(
            (elapsed - startTimes[i]) / MOTION_SETTLE_SECONDS,
            1.0,
          );
          motionScalar = THREE.MathUtils.lerp(
            MOTION_INITIAL,
            MOTION_REST,
            easeInOutCubic(settle),
          );
        }
      }

      // NOTE(review): for particles still at NOT_STARTED, elapsed - startTimes[i]
      // equals elapsed + 1, so they grow as if activated at t = -1. This matches
      // the previous behaviour; confirm whether unstarted particles should
      // instead stay at their start size.
      if (introActive) {
        scale = growIn(baseScale, elapsed - startTimes[i], easeInOutCubic);
      } else if (formationActive) {
        scale = growIn(baseScale, elapsed - startTimes[i], easeInCubic);
      }

      // Small oscillation along this particle's random direction.
      const wobble =
        Math.sin(elapsed * randomSpeeds[i] + i * 0.2) * motionScalar;
      const dir = randomDirections[i];
      scratchMatrix
        .makeScale(scale, scale, scale)
        .setPosition(px + dir.x * wobble, py + dir.y * wobble, pz + dir.z * wobble);
      mesh.setMatrixAt(i, scratchMatrix);
    }
    mesh.instanceMatrix.needsUpdate = true;

    const { uniforms } = material;
    uniforms.uTime.value = elapsed;
    uniforms.brightness.value = detailOnly ? LOW_BRIGHTNESS : brightness;
    uniforms.runnerY.value = runnerY;
  });

  return (
    <group>
      <instancedMesh
        ref={instancedMeshRef}
        args={[undefined, undefined, count]} // geometry, material, instance count
        material={material}
      >
        <sphereGeometry args={[0.001, 16, 16]}>
          <instancedBufferAttribute
            attach="attributes-instanceColor"
            args={[attributes.colors, 3]}
          />
          <instancedBufferAttribute
            attach="attributes-instanceY"
            args={[attributes.y, 1]}
          />
          <instancedBufferAttribute
            attach="attributes-instanceOpacity"
            args={[attributes.opacities, 1]}
          />
        </sphereGeometry>
      </instancedMesh>
    </group>
  );
};

/**
 * Renders one particle cloud per entry in `treeModels`, all shifted vertically
 * by PARTICLE_Y_OFFSET.
 *
 * `dispose={null}` is presumably there to stop react-three-fiber from
 * auto-disposing the shared GLTF resources on unmount — confirm.
 */
export const Particles = () => (
  <group dispose={null} position={[0, PARTICLE_Y_OFFSET, 0]}>
    {treeModels.map((model, index) => (
      <ModelAsParticles
        key={index}
        modelPath={model.modelPath}
        nodesName={model.nodeName}
      />
    ))}
  </group>
);

/**
 * Renders a `detailOnly` (half-size, dimmed) particle copy of the first tree
 * model as background detail. The copy is scaled 1.5x and raised 0.5 above
 * PARTICLE_Y_OFFSET.
 *
 * Renders nothing when `treeModels` is empty.
 */
export const AdditionalBackgroundParticles = () => {
  const model = treeModels[0];
  // treeModels[0] is undefined for an empty list; bail out rather than throw
  // on `model.modelPath` below.
  if (model === undefined) return null;

  const scale = 1.5;
  return (
    <group
      position={[0, PARTICLE_Y_OFFSET + 0.5, 0]}
      scale={[scale, scale, scale]}
    >
      <ModelAsParticles
        modelPath={model.modelPath}
        nodesName={model.nodeName}
        detailOnly
      />
    </group>
  );
};

// Kick off loading of every tree model at module-evaluation time, presumably
// so the first mount does not wait on the download.
for (const { modelPath } of treeModels) {
  useGLTF.preload(modelPath);
}
