import React, { useMemo, useRef, useState, useEffect, useCallback } from 'react';
import { useFrame } from '@react-three/fiber';
import * as THREE from 'three';
import PropTypes from 'prop-types';

// Solar Particle component for individual flares
const SolarParticle = ({ position, scale }) => {
  const meshRef = useRef();
  const [alive, setAlive] = useState(true);
  const initialDistance = useRef(null);
  
  // Create a reusable vector3 for calculations
  const direction = useMemo(() => new THREE.Vector3(), []);
  
  // Create geometry outside component to be reused
  const geometry = useMemo(() => new THREE.SphereGeometry(54817, 16, 16), []);
  
  // Simplified shader for solid appearance
  const particleShader = useMemo(() => ({
    uniforms: {
      time: { value: 0 },
    },
    vertexShader: `
      varying vec3 vNormal;
      
      void main() {
        vNormal = normal;
        gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
      }
    `,
    fragmentShader: `
      varying vec3 vNormal;
      uniform float time;

      void main() {
        vec3 baseColor = vec3(1.0, 0.8, 0.3);  // Brighter orange-yellow
        
        // Add subtle radial gradient
        float intensity = dot(vNormal, vec3(0.0, 0.0, 1.0)) * 0.15 + 0.85;
        vec3 finalColor = baseColor * intensity * 1.5; // Increased overall brightness
        
        gl_FragColor = vec4(finalColor, 0.7); // Increased opacity
      }
    `
  }), []);

  useFrame((state) => {
    if (meshRef.current && alive) {
      const mesh = meshRef.current;
      const pos = mesh.position;
      
      // Update shader time uniform
      mesh.material.uniforms.time.value = state.clock.getElapsedTime();
      
      // Set initial distance on first frame
      if (initialDistance.current === null) {
        initialDistance.current = pos.length();
      }
      
      // Calculate direction without creating new vectors
      direction.copy(pos).normalize();
      pos.add(direction.multiplyScalar(1000));
      
      // Faster rotation for optimization
      mesh.rotation.x += 0.02;
      mesh.rotation.y += 0.02;
      
      // Calculate distance-based opacity decay
      const currentDistance = pos.length();
      const distanceRatio = currentDistance / initialDistance.current;
      
      // Changed from 1.5 to 1.2 to make particles die sooner
      if (distanceRatio > 1.05) setAlive(false);
    }
  });

  if (!alive) return null;

  return (
    <mesh 
      ref={meshRef} 
      position={position} 
      scale={scale}
      renderOrder={2}
      geometry={geometry}
    >
      <shaderMaterial 
        attach="material"
        {...particleShader}
        transparent={true}
        depthWrite={false}
        blending={THREE.AdditiveBlending}
        opacity={0.99}
      />
    </mesh>
  );
};

const Sun = ({ isSelected }) => {
  const [particles, setParticles] = useState([]);
  const particleCount = useRef(0);
  const lastUpdate = useRef(0);
  
  // Cache calculations
  const sunRadius = 696340;
  const twoPI = Math.PI * 2;
  
  // Reusable vectors for particle position calculation
  const positionVector = useMemo(() => new THREE.Vector3(), []);

  const calculateParticlePosition = useCallback(() => {
    const theta = Math.random() * twoPI;
    const phi = Math.acos(2 * Math.random() - 1);
    
    positionVector.set(
      sunRadius * Math.sin(phi) * Math.cos(theta),
      sunRadius * Math.sin(phi) * Math.sin(theta),
      sunRadius * Math.cos(phi)
    );
    
    return [positionVector.x, positionVector.y, positionVector.z];
  }, [positionVector]);

  useFrame((state) => {
    const time = state.clock.getElapsedTime();
    
    // Only spawn particles if the Sun is selected
    if (isSelected && Math.random() < 0.42) {
      setParticles(current => [
        ...current,
        {
          id: particleCount.current++,
          position: calculateParticlePosition(),
          scale: [0.5, 0.5, 0.5]
        }
      ]);
    }

    // Keep cleaning up particles even when not selected
    if (particles.length > 200) {
      setParticles(current => current.slice(1));
    }
  });

  const jupiterShader = useMemo(() => {
    return {
      uniforms: {
        time: { value: 0 },
        modelMatrix: { value: new THREE.Matrix4() },
      },
      vertexShader: `
        varying vec2 vUv;
        varying vec3 vNormal;
        varying vec3 vPosition;
        
        void main() {
          // Apply 30-degree rotation around X axis
          float angle = radians(30.0);
          mat3 rotationMatrix = mat3(
            1.0, 0.0, 0.0,
            0.0, cos(angle), -sin(angle),
            0.0, sin(angle), cos(angle)
          );
          
          vec3 rotatedPosition = rotationMatrix * position;
          vec3 rotatedNormal = rotationMatrix * normal;
          
          vUv = uv;
          vNormal = rotatedNormal;
          vPosition = rotatedPosition;
          gl_Position = projectionMatrix * modelViewMatrix * vec4(rotatedPosition, 1.0);
        }
      `,
      fragmentShader: `
        varying vec2 vUv;
        varying vec3 vNormal;
        varying vec3 vPosition;
        uniform float time;
        uniform mat4 modelMatrix;

        const float PI = 3.14159265359;

        // Noise functions
        float random(vec2 st) {
          return fract(sin(dot(st.xy, vec2(12.9898,78.233))) * 43758.5453123);
        }
        
        // Value noise
        float noise(vec2 st) {
          vec2 i = floor(st);
          vec2 f = fract(st);
          float a = random(i);
          float b = random(i + vec2(1.0, 0.0));
          float c = random(i + vec2(0.0, 1.0));
          float d = random(i + vec2(1.0, 1.0));
          vec2 u = f * f * (1.0 - 1.0 * f);
          return mix(a, b, u.x) + (c - a)* u.y * (1.0 - u.x) + (d - b) * u.x * u.y;
        }

        void main() {
          vec3 baseColor = vec3(1.0, 0.8, 0.3); // Even brighter orange sun base color
          
          // Keep the wave calculations
          float angle = atan(vPosition.z, vPosition.x);
          float wave1 = sin(angle * 10.0 + time * 0.4) * 0.04;
          float wave2 = sin(angle * 8.0 + time * 0.3 + PI * 1.5) * 0.03;
          float wave3 = sin(angle * 6.0 - time * 0.5 + PI * 0.8) * 0.02;
          float modifiedY = vUv.y + wave1 + wave2 + wave3;
          
          float bands = noise(vec2(modifiedY * 1.0, time * 0.1)) * 0.7;
          bands += noise(vec2(modifiedY * 30.0, time * 0.05)) * 0.3;
          
          vec3 bandColor;
          
          // Brighter color bands for sun
          if (modifiedY > 0.75) {
              bandColor = mix(
                  vec3(1.0, 0.98, 0.9),   // Nearly white
                  vec3(1.0, 1.0, 1.0),    // Pure white
                  bands
              );
          } else if (modifiedY > 0.5) {
              bandColor = mix(
                  vec3(1.0, 0.8, 0.4),    // Brighter deep orange
                  vec3(1.0, 0.95, 0.5),   // Brighter bright orange
                  bands
              );
          } else if (modifiedY > 0.25) {
              bandColor = mix(
                  vec3(1.0, 0.6, 0.3),    // Brighter deep red-orange
                  vec3(1.0, 0.8, 0.4),    // Brighter bright orange
                  bands
              );
          } else {
              bandColor = mix(
                  vec3(1.0, 0.5, 0.3),    // Brighter dark orange-red
                  vec3(1.0, 0.7, 0.3),    // Brighter medium orange
                  bands
              );
          }
          
          // Intensify solar flare effect
          bandColor *= mix(
              vec3(1.4, 1.3, 1.1),    // Brighter white-yellow
              vec3(1.2, 1.0, 0.8),    // Brighter warm orange
              sin(modifiedY * 12.0) * 0.5 + 0.5
          );
          
          vec3 finalColor = mix(bandColor, baseColor, 0.1) * 1.3; // Increased overall brightness
          
          // Calculate lighting
          vec3 worldPos = (modelMatrix * vec4(vPosition, 1.0)).xyz;
          vec3 toSun = normalize(-worldPos);
          vec3 worldNormal = normalize(mat3(modelMatrix) * vNormal);
          float dotProduct = dot(worldNormal, toSun);
          
          // Reduced darkening at edges
          if (dotProduct < 0.3) {
              finalColor *= 0.9; // Changed from 0.8 to 0.9 for less darkening
          }
          
          gl_FragColor = vec4(finalColor, 1.0);
        }
      `
    };
  }, []);

  return (
    <group>
      {/* Main sun sphere */}
      <mesh renderOrder={1}>
        <sphereGeometry args={[696340, 64, 64]} />
        <shaderMaterial 
          attach="material"
          {...jupiterShader}
          side={THREE.FrontSide}
          transparent={true}
          depthWrite={false}
          depthTest={true}
          emissive={new THREE.Color(1, 1, 0.8)}
          emissiveIntensity={5.0}
          opacity={0.99}
        />
      </mesh>

      {/* Solar particles */}
      {particles.map(particle => (
        <SolarParticle
          key={particle.id}
          position={particle.position}
          scale={particle.scale}
          depthWrite={true}
          depthTest={true}
        />
      ))}
    </group>
  );
};

// Runtime validation of Sun's props.
Sun.propTypes = {
  isSelected: PropTypes.bool,
};

// No `Sun.defaultProps`: React deprecates defaultProps on function components
// (React 19 removes support). When `isSelected` is omitted it is undefined,
// which Sun treats the same as false (no particles are spawned).
export default Sun;
