import { Matrix, Quaternion, Vector2, Vector3 } from "@babylonjs/core";

/**
 * Selects the formula used by `Quaternion.prototype.asLHQuaternion` (see below).
 *
 * ES-module importers receive a read-only live binding, so they cannot assign to this
 * flag directly; use `setUseCorrectedLHConversion` to change it from another module.
 */
export let useCorrectedLHConversion = false;

/**
 * Sets `useCorrectedLHConversion`. Exists because an imported binding cannot be
 * reassigned by the importing module.
 *
 * @param value - `true` to make `asLHQuaternion` use the alternative conversion.
 */
export function setUseCorrectedLHConversion(value: boolean): void {
  useCorrectedLHConversion = value;
}

// Module augmentation: declares the handedness-conversion helpers on Babylon's Vector3 and
// Quaternion types so the prototype assignments in this file type-check for callers.
declare module "@babylonjs/core/Maths/math.vector" {
  interface Vector3 {
    /** Returns a new vector with the X component negated (this vector is not modified). */
    asLHVector3(): Vector3;
  }
  interface Quaternion {
    /** Returns a new quaternion for the X-mirrored frame; the formula depends on `useCorrectedLHConversion`. */
    asLHQuaternion(): Quaternion;
  }
}

/**
 * Mirrors this vector across the YZ plane (negates X).
 * Returns a fresh Vector3; `this` is left unchanged.
 */
Vector3.prototype.asLHVector3 = function (): Vector3 {
  const { x, y, z } = this;
  return new Vector3(-x, y, z);
};

/**
 * Converts this rotation for the X-mirrored frame produced by `asLHVector3`.
 *
 * - Default: returns (x, -y, -z, w).
 * - With `useCorrectedLHConversion` set: returns (-x, y, z, w), the conjugate of the default
 *   result (same axis, opposite rotation).
 *
 * NOTE(review): conjugating a rotation by an X-axis mirror gives (x, -y, -z, w); the
 * "corrected" branch returns its inverse. Confirm which one callers expect (e.g. whether the
 * input rotation is world-to-camera or camera-to-world).
 */
Quaternion.prototype.asLHQuaternion = function (): Quaternion {
  if (useCorrectedLHConversion) {
    return new Quaternion(-this.x, this.y, this.z, this.w);
  }
  return new Quaternion(this.x, -this.y, -this.z, this.w);
};

/**
 * Computes where the camera ray through an undistorted image point meets the horizontal
 * plane `y = groundFloorHeight` of a Y-up world.
 *
 * The camera is modelled as a pinhole looking along its local -Z axis. The camera-space ray is
 * rotated by `cameraRotation` to obtain its world-space direction, so the quaternion must be
 * the one that maps camera-space directions into world space (the camera's orientation).
 *
 * NOTE(review): the image Y offset is used without a sign flip, so increasing pixel `y` maps to
 * +Y in camera space. `projectPointToGroundPlane` flips Y (`ndcY = -2y + 1`); confirm which
 * convention the intrinsics/rotation passed here assume.
 *
 * @param undistortedPoint - The undistorted 2D image point (pixels).
 * @param cx - Principal point x-coordinate.
 * @param cy - Principal point y-coordinate.
 * @param fx - Focal length in the x-direction (pixels).
 * @param fy - Focal length in the y-direction (pixels).
 * @param cameraRotation - Rotation taking camera-space directions to world space.
 * @param cameraTranslation - Camera position in world coordinates.
 * @param groundFloorHeight - Height of the ground plane in world space.
 * @returns The intersection point in world coordinates, or `undefined` when the ray is
 *   parallel to the plane or the plane lies behind the camera along the ray.
 */
export function computeGroundIntersection(
  undistortedPoint: Vector2, cx: number, cy: number, fx: number, fy: number,
  cameraRotation: Quaternion, cameraTranslation: Vector3, groundFloorHeight: number
): Vector3 | undefined {
  const rotationMatrix = new Matrix();
  cameraRotation.toRotationMatrix(rotationMatrix);

  // Step 1: back-project through the pinhole; -Z is the viewing direction.
  const rayCamera = new Vector3(
    (undistortedPoint.x - cx) / fx,
    (undistortedPoint.y - cy) / fy,
    -1
  ).normalize();

  // Step 2: rotate into world space. TransformNormal ignores translation, as a direction should.
  const rayWorld = Vector3.TransformNormal(rayCamera, rotationMatrix);
  const originWorld = cameraTranslation;

  // Step 3: solve originWorld.y + t * rayWorld.y = groundFloorHeight for t.
  const heightDifference = originWorld.y - groundFloorHeight;
  if (rayWorld.y === 0) return undefined; // Ray is parallel to the ground plane

  const t = -heightDifference / rayWorld.y;
  if (t < 0) return undefined; // The plane is behind the camera along this ray

  return new Vector3(
    originWorld.x + t * rayWorld.x,
    groundFloorHeight, // exact plane height instead of the rounded origin.y + t * ray.y
    originWorld.z + t * rayWorld.z
  );
}


/**
 * Projects a normalized image point onto the horizontal plane `y = yGround`.
 *
 * The camera is a pinhole looking along its local -Z axis. `imagePoint` is in [0, 1] on both
 * axes: x = 0 maps to the left edge (-X in camera space) and y = 0 to the top edge (+Y in
 * camera space), i.e. image Y is flipped when converting to NDC.
 *
 * @param imagePoint - Point in normalized image coordinates.
 * @param aspectRatio - Image width / height.
 * @param cameraDiagFovDeg - Diagonal field of view, in degrees.
 * @param cameraPosition - Camera position in world space.
 * @param cameraOrientation - Rotation taking camera-space directions to world space.
 * @param yGround - Height of the ground plane.
 * @returns The point on the ray's supporting line at height `yGround`. No validity check is
 *   made: if the ray points away from the plane, the point lies behind the camera, and a ray
 *   parallel to the plane yields non-finite x/z components.
 */
export function projectPointToGroundPlane(
  imagePoint: Vector2,
  aspectRatio: number,
  cameraDiagFovDeg: number,
  cameraPosition: Vector3,
  cameraOrientation: Quaternion,
  yGround: number
): Vector3 {
  // tan of the half FOVs, derived directly from the diagonal FOV:
  //   tan(h/2) = tan(d/2) * a / sqrt(1 + a^2),  tan(v/2) = tan(d/2) / sqrt(1 + a^2)
  const tanHalfDiagonal = Math.tan((Math.PI * cameraDiagFovDeg / 180.0) / 2);
  const diagonalFactor = Math.sqrt(1 + aspectRatio * aspectRatio);
  const tanHalfHorizontal = tanHalfDiagonal * aspectRatio / diagonalFactor;
  const tanHalfVertical = tanHalfDiagonal / diagonalFactor;

  // Normalized device coordinates; image Y grows downwards, camera Y upwards.
  const ndcX = 2.0 * imagePoint.x - 1.0;
  const ndcY = -2.0 * imagePoint.y + 1.0;

  const directionCameraSpace = new Vector3(tanHalfHorizontal * ndcX, tanHalfVertical * ndcY, -1);

  const rotationMatrix = new Matrix();
  cameraOrientation.toRotationMatrix(rotationMatrix);

  // Directions must not pick up translation or a perspective divide.
  const directionWorldSpace = Vector3.TransformNormal(directionCameraSpace, rotationMatrix);

  // Ray parameter at which cameraPosition + t * direction reaches y = yGround.
  const t = (yGround - cameraPosition.y) / directionWorldSpace.y;

  return new Vector3(
    cameraPosition.x + t * directionWorldSpace.x,
    yGround,
    cameraPosition.z + t * directionWorldSpace.z
  );
}

// export function calculateCameraMatrix(fovDiagonalDeg: number, width: number, height: number): cv.Mat {
//     // Convert FOV from degrees to radians
//     const fovDiagonalRad = (fovDiagonalDeg * Math.PI) / 180;
  
//     // Calculate the diagonal length of the image
//     const imageDiagonal = Math.sqrt(width ** 2 + height ** 2);
  
//     // Calculate the focal length using the diagonal FOV
//     const focalLength = imageDiagonal / (2 * Math.tan(fovDiagonalRad / 2));
  
//     // Principal point is centered
//     const cx = width / 2;
//     const cy = height / 2;
  
//     // Assuming square pixels, fx = fy
//     const fx = focalLength;
//     const fy = focalLength;
  
//     // Construct the camera matrix
//     const cameraMatrixData = [
//       [fx, 0, cx],
//       [0, fy, cy],
//       [0, 0, 1],
//     ];
  
//     const cameraMatrix = cv.matFromArray(3, 3, cv.CV_64F, cameraMatrixData.flat());
  
//     return cameraMatrix;
//   }
  

  /**
   * Removes Brown–Conrady lens distortion (radial k1..k3, tangential p1/p2) from a pixel.
   *
   * Intrinsics are derived from the diagonal FOV assuming square pixels (fx = fy) and a
   * principal point at the image centre.
   *
   * Inversion uses the fixed-point scheme of OpenCV's `undistortPoints`: with (x0, y0) the
   * observed (distorted) normalized coordinate, repeat
   *   x = (x0 - deltaX(x, y)) / radial(x, y),  y = (y0 - deltaY(x, y)) / radial(x, y).
   * The residual is always taken against the observed point, so extra iterations converge
   * rather than drift.
   *
   * @param point - Distorted pixel coordinate.
   * @param width - Image width in pixels.
   * @param height - Image height in pixels.
   * @param diagonalFovDegrees - Diagonal field of view in degrees.
   * @param k1 - Radial coefficient (r^2).
   * @param k2 - Radial coefficient (r^4).
   * @param k3 - Radial coefficient (r^6).
   * @param p1 - Tangential coefficient.
   * @param p2 - Tangential coefficient.
   * @param iterations - Number of fixed-point iterations (default 3).
   * @returns The undistorted pixel coordinate.
   */
  export function undistortPointBrownConrady(
    point: Vector2, width: number, height: number, diagonalFovDegrees: number,
    k1: number, k2: number, k3: number, p1: number, p2: number,
    iterations: number = 3): Vector2 {

    const diagonalFovRadians = diagonalFovDegrees * (Math.PI / 180);
    const imageDiagonal = Math.sqrt(width * width + height * height);
    const f = imageDiagonal / (2 * Math.tan(diagonalFovRadians / 2));

    const cx = width / 2;
    const cy = height / 2;

    // Observed (distorted) normalized coordinates; fixed for the whole iteration.
    const x0 = (point.x - cx) / f;
    const y0 = (point.y - cy) / f;

    // Initial guess: the distorted point itself.
    let x = x0;
    let y = y0;
    for (let i = 0; i < iterations; i++) {
      const r2 = x * x + y * y;
      const radial = 1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2;
      const deltaX = 2 * p1 * x * y + p2 * (r2 + 2 * x * x);
      const deltaY = p1 * (r2 + 2 * y * y) + 2 * p2 * x * y;
      x = (x0 - deltaX) / radial;
      y = (y0 - deltaY) / radial;
    }

    return new Vector2(x * f + cx, y * f + cy);
  }

/**
 * Removes a 4-term radial distortion, radial(r) = 1 + k1 r^2 + k2 r^4 + k3 r^6 + k4 r^8,
 * from a pixel coordinate given explicit intrinsics.
 *
 * NOTE(review): despite the name, this is a polynomial in the normalized radius r, not the
 * equidistant (theta-based) fisheye model of OpenCV's `cv::fisheye`; confirm that matches how
 * k1..k4 were calibrated.
 *
 * Inversion is the fixed-point iteration x = x0 / radial(x, y), where (x0, y0) is the observed
 * distorted normalized coordinate, so additional iterations converge instead of drifting.
 *
 * @param point - Distorted pixel coordinate.
 * @param cx - Principal point x.
 * @param cy - Principal point y.
 * @param fx - Focal length x (pixels).
 * @param fy - Focal length y (pixels).
 * @param k1 - Radial coefficient (r^2).
 * @param k2 - Radial coefficient (r^4).
 * @param k3 - Radial coefficient (r^6).
 * @param k4 - Radial coefficient (r^8).
 * @param iterations - Number of fixed-point iterations (default 3).
 * @returns The undistorted pixel coordinate.
 */
export function undistortPointFisheye4(
  point: Vector2, cx: number, cy: number, fx: number, fy: number,
  k1: number, k2: number, k3: number, k4: number,
  iterations: number = 3
): Vector2 {
  // Observed (distorted) normalized coordinates; fixed for the whole iteration.
  const x0 = (point.x - cx) / fx;
  const y0 = (point.y - cy) / fy;

  let x = x0;
  let y = y0;
  for (let i = 0; i < iterations; i++) {
    const r2 = x * x + y * y;
    const radial = 1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2 + k4 * r2 * r2 * r2 * r2;
    x = x0 / radial;
    y = y0 / radial;
  }

  return new Vector2(x * fx + cx, y * fy + cy);
}

/**
 * Removes the same 4-term radial distortion as `undistortPointFisheye4`, deriving the
 * focal lengths from a diagonal field of view.
 *
 * Inversion is the fixed-point iteration x = x0 / radial(x, y) against the observed
 * (distorted) normalized coordinate, so additional iterations converge instead of drifting.
 *
 * @param point - Distorted pixel coordinate.
 * @param width - Image width in pixels.
 * @param height - Image height in pixels.
 * @param cx - Principal point x.
 * @param cy - Principal point y.
 * @param diagonalFovDegrees - Diagonal field of view in degrees.
 * @param k1 - Radial coefficient (r^2).
 * @param k2 - Radial coefficient (r^4).
 * @param k3 - Radial coefficient (r^6).
 * @param k4 - Radial coefficient (r^8).
 * @param iterations - Number of fixed-point iterations (default 3).
 * @returns The undistorted pixel coordinate.
 */
export function undistortPointFisheye4Fov(
  point: Vector2, width: number, height: number, cx: number, cy: number,
  diagonalFovDegrees: number, k1: number, k2: number, k3: number, k4: number,
  iterations: number = 3
): Vector2 {
  // Focal length (pixels) implied by the diagonal FOV.
  const diagonalFovRadians = (diagonalFovDegrees * Math.PI) / 180;
  const imageDiagonal = Math.sqrt(width * width + height * height);
  const f = imageDiagonal / (2 * Math.tan(diagonalFovRadians / 2));

  // NOTE(review): these simplify to fx = height / (2 tan(fov/2)) and fy = width / (2 tan(fov/2)),
  // unlike undistortPointBrownConrady which uses fx = fy = f. Kept unchanged because k1..k4 may
  // have been fitted against this normalization — confirm.
  const aspectRatio = width / height;
  const fx = f / Math.sqrt(1 + aspectRatio * aspectRatio);
  const fy = f / Math.sqrt(1 + (1 / aspectRatio) * (1 / aspectRatio));

  // Observed (distorted) normalized coordinates; fixed for the whole iteration.
  const x0 = (point.x - cx) / fx;
  const y0 = (point.y - cy) / fy;

  let x = x0;
  let y = y0;
  for (let i = 0; i < iterations; i++) {
    const r2 = x * x + y * y;
    const radial = 1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2 + k4 * r2 * r2 * r2 * r2;
    x = x0 / radial;
    y = y0 / radial;
  }

  return new Vector2(x * fx + cx, y * fy + cy);
}


/**
 * Maps a pixel of an image that is letterboxed ("contain"-fitted and centred) inside a
 * container to the corresponding container coordinate.
 *
 * @param imagePoint - Point in image pixels.
 * @param containerSize - Container width/height.
 * @param imageSize - Image width/height.
 * @returns The point in container coordinates.
 */
export function imagePointToContainerPoint(
  imagePoint: Vector2,
  containerSize: Vector2,
  imageSize: Vector2
) {
  const containerIsWider =
    containerSize.x / containerSize.y > imageSize.x / imageSize.y;

  if (containerIsWider) {
    // Image height fills the container; equal bars left and right.
    const heightScale = containerSize.y / imageSize.y;
    const barWidth = (containerSize.x - imageSize.x * heightScale) / 2;
    return new Vector2(imagePoint.x * heightScale + barWidth, imagePoint.y * heightScale);
  }

  // Image width fills the container; equal bars above and below.
  const widthScale = containerSize.x / imageSize.x;
  const barHeight = (containerSize.y - imageSize.y * widthScale) / 2;
  return new Vector2(imagePoint.x * widthScale, imagePoint.y * widthScale + barHeight);
}

/**
 * Inverse of the letterbox mapping: converts a container coordinate back to image pixels for
 * an image "contain"-fitted and centred inside the container.
 *
 * @param containerPoint - Point in container coordinates.
 * @param containerSize - Container width/height.
 * @param imageSize - Image width/height.
 * @returns The point in image pixels (may fall outside the image when on a bar).
 */
export function containerPointToImagePoint(
  containerPoint: Vector2,
  containerSize: Vector2,
  imageSize: Vector2
) {
  const containerIsWider =
    containerSize.x / containerSize.y > imageSize.x / imageSize.y;

  if (!containerIsWider) {
    // Image width fills the container; equal bars above and below.
    const widthScale = containerSize.x / imageSize.x;
    const barHeight = (containerSize.y - imageSize.y * widthScale) / 2;
    return new Vector2(containerPoint.x / widthScale, (containerPoint.y - barHeight) / widthScale);
  }

  // Image height fills the container; equal bars left and right.
  const heightScale = containerSize.y / imageSize.y;
  const barWidth = (containerSize.x - imageSize.x * heightScale) / 2;
  return new Vector2((containerPoint.x - barWidth) / heightScale, containerPoint.y / heightScale);
}

/**
 * Maps a point of the virtual (rectilinear) view to the fisheye-image UV it samples.
 *
 * Pipeline (the step comments reference the shader this reproduces):
 *  1. virtual-view UV -> screen coordinates in [-1, 1], x stretched by the aspect ratio;
 *  2-3. pinhole ray (screen.x, screen.y, f), f = 1 / tan(virtualViewFov / 2), normalized;
 *  4. rotate by the roll and pitch matrices;
 *  5-7. angle theta from +Z and azimuth phi; r = theta / fishEyeFov (r = 0.5 at the FOV edge);
 *  8. radial distortion r_corr = r * (1 + k1 r^2 + k2 r^4 + k3 r^6 + k4 r^8);
 *  9. fisheye UV = (0.5, 0.5) + r_corr * (cos phi, sin phi).
 *
 * @param virtualViewPoint - Virtual-view UV in [0, 1].
 * @param fishEyeFov - Full fisheye capture FOV in radians.
 * @param fisheyeDistCoef - Radial distortion coefficients [k1, k2, k3, k4]; all four are read.
 * @param virtualViewRoll - Virtual view roll in radians.
 * @param virtualViewPitch - Virtual view pitch in radians.
 * @param virtualViewFov - Virtual view vertical FOV in radians.
 * @param virtualViewAspect - Virtual view aspect ratio (width / height).
 * @returns The fisheye UV, or (-1, -1) when the ray lies outside the fisheye FOV. The result is
 *   not clamped: with positive distortion it can lie beyond radius 0.5 from the centre.
 */
export function transformVirtualViewToFisheye(
  virtualViewPoint: Vector2,
  fishEyeFov: number,
  fisheyeDistCoef: number[],
  virtualViewRoll: number,
  virtualViewPitch: number,
  virtualViewFov: number,
  virtualViewAspect: number
): Vector2 {
  // 1. Virtual-view UV [0, 1] -> screen coordinates [-1, 1]; x is scaled by the aspect ratio.
  const screen = new Vector2(
    (virtualViewPoint.x - 0.5) * 2,
    (virtualViewPoint.y - 0.5) * 2
  );
  screen.x *= virtualViewAspect;

  // 2. Focal length for the virtual view's vertical FOV.
  const f = 1.0 / Math.tan(virtualViewFov / 2);

  // 3. Unit ray through the screen point (scaleInPlace mutates `ray`).
  const ray = new Vector3(screen.x, screen.y, f);
  const rayLen = ray.length();
  if (rayLen !== 0) {
    ray.scaleInPlace(1 / rayLen);
  }

  // 4. Apply the virtual view's rotations.
  // NOTE(review): Babylon.js transforms row vectors (v * M), so each matrix below acts as the
  // transpose of its written layout and `rollMatrix.multiply(pitchMatrix)` applies roll before
  // pitch. Any inverse (e.g. applyInversePitchRoll) must follow the same convention.
  const cosPitch = Math.cos(virtualViewPitch);
  const sinPitch = Math.sin(virtualViewPitch);
  // Rotation about X-axis (pitch)
  const pitchMatrix = Matrix.FromValues(
    1, 0, 0, 0,
    0, cosPitch, -sinPitch, 0,
    0, sinPitch, cosPitch, 0,
    0, 0, 0, 1
  );

  const cosRoll = Math.cos(virtualViewRoll);
  const sinRoll = Math.sin(virtualViewRoll);
  // Rotation about Z-axis (roll)
  const rollMatrix = Matrix.FromValues(
    cosRoll, -sinRoll, 0, 0,
    sinRoll, cosRoll, 0, 0,
    0, 0, 1, 0,
    0, 0, 0, 1
  );

  const transformMatrix = rollMatrix.multiply(pitchMatrix);
  const rotatedRay = Vector3.TransformCoordinates(ray, transformMatrix);
  const rotatedLen = rotatedRay.length();
  if (rotatedLen !== 0) {
    rotatedRay.scaleInPlace(1 / rotatedLen);
  }

  // 5. theta: angle between the rotated ray and the fisheye optical axis (+Z).
  const clampedZ = Math.max(-1, Math.min(1, rotatedRay.z));
  const theta = Math.acos(clampedZ);

  // Outside the fisheye capture cone: return the "invalid" marker.
  if (theta > fishEyeFov / 2) {
    return new Vector2(-1, -1);
  }

  // 6. phi: azimuth of the rotated ray around the optical axis.
  const phi = Math.atan2(rotatedRay.y, rotatedRay.x);

  // 7. Map theta linearly to a normalized radius: r = (theta/(fishEyeFov/2))*0.5 = theta/fishEyeFov.
  const r = (theta / (fishEyeFov / 2)) * 0.5;

  // 8. Apply the 4-term radial distortion polynomial:
  //    r_corr = r * (1 + k1*r^2 + k2*r^4 + k3*r^6 + k4*r^8).
  const r2 = r * r;
  const r4 = r2 * r2;
  const r6 = r4 * r2;
  const r8 = r4 * r4;
  const r_corr = r * (1 +
    fisheyeDistCoef[0] * r2 +
    fisheyeDistCoef[1] * r4 +
    fisheyeDistCoef[2] * r6 +
    fisheyeDistCoef[3] * r8
  );

  // 9. Fisheye UV around the image centre (0.5, 0.5):
  //    texCoords = vec2(0.5 + r_corr*cos(phi), 0.5 + r_corr*sin(phi)).
  const u_out = 0.5 + r_corr * Math.cos(phi);
  const v_out = 0.5 + r_corr * Math.sin(phi);

  return new Vector2(u_out, v_out);
}


/**
 * Inverse of `transformVirtualViewToFisheye(...)`: given a fisheye UV, find the virtual-view
 * UV that maps to it.
 *
 * @param fisheyePoint  The fisheye-space UV in [0..1].
 * @param fishEyeFov    Full fisheye FOV in radians (same as forward function).
 * @param fisheyeDistCoef The 4 distortion coefficients [k1, k2, k3, k4].
 * @param virtualViewRoll  The roll of the virtual view (radians).
 * @param virtualViewPitch The pitch of the virtual view (radians).
 * @param virtualViewFov   The vertical FOV of the virtual view (radians); assumed in (0, PI).
 * @param virtualViewAspect The aspect ratio (width/height) of the virtual view.
 * @returns The virtual-view UV in [0..1], or (-1, -1) when the fisheye point has no preimage
 *   (outside the capture region, behind the virtual camera, or off the virtual screen).
 */
export function transformFisheyeToVirtualView(
  fisheyePoint: Vector2,
  fishEyeFov: number,
  fisheyeDistCoef: number[],
  virtualViewRoll: number,
  virtualViewPitch: number,
  virtualViewFov: number,
  virtualViewAspect: number
): Vector2 {
  // 1. Signed offset from the fisheye image centre (0.5, 0.5).
  const x = fisheyePoint.x - 0.5;
  const y = fisheyePoint.y - 0.5;

  // 2. Distorted radius and azimuth (forward placed the point at centre + r_corr*(cos, sin)(phi)).
  const r_corr = Math.sqrt(x * x + y * y);
  const phi = Math.atan2(y, x);

  // The forward pass bounds the *undistorted* radius (r <= 0.5, i.e. theta <= fov/2), so the
  // largest distorted radius it can emit is distort(0.5). Beyond that the bisection in
  // invertDistortion would silently clamp to 0.5, so reject here. The 0.5 bound (edge of the
  // inscribed image circle) is kept as well.
  const rMax = 0.5;
  const rMax2 = rMax * rMax;
  const rMax4 = rMax2 * rMax2;
  const maxRCorr = rMax * (1 +
    fisheyeDistCoef[0] * rMax2 +
    fisheyeDistCoef[1] * rMax4 +
    fisheyeDistCoef[2] * rMax4 * rMax2 +
    fisheyeDistCoef[3] * rMax4 * rMax4);
  if (r_corr > 0.5 || r_corr > maxRCorr) {
    return new Vector2(-1, -1);
  }

  // 3. Undo the radial distortion numerically: r_corr = r * (1 + k1 r^2 + ... + k4 r^8).
  const r = invertDistortion(r_corr, fisheyeDistCoef);
  if (r < 0.0) {
    return new Vector2(-1, -1);
  }

  // 4. Forward used r = theta / fishEyeFov.
  const theta = r * fishEyeFov;
  if (theta > fishEyeFov * 0.5) {
    return new Vector2(-1, -1);
  }

  // 5. Unit direction in the rotated (fisheye) frame from spherical coordinates.
  const sinTheta = Math.sin(theta);
  const rotatedRay = new Vector3(
    sinTheta * Math.cos(phi),
    sinTheta * Math.sin(phi),
    Math.cos(theta)
  );

  // 6. Undo the forward roll/pitch rotation.
  const inverseRotatedRay = applyInversePitchRoll(rotatedRay, virtualViewPitch, virtualViewRoll);

  // 7. Forward rays are normalize(screen.x, screen.y, f) with f > 0, so every valid direction
  //    has z > 0. Directions at or behind the virtual image plane have no preimage; scaling
  //    them by f / z would mirror them onto the screen.
  const f = 1.0 / Math.tan(virtualViewFov * 0.5);
  if (inverseRotatedRay.z < 1e-8) {
    return new Vector2(-1, -1);
  }
  const scale = f / inverseRotatedRay.z;
  const sx = inverseRotatedRay.x * scale;
  const sy = inverseRotatedRay.y * scale;

  // 8-9. Undo the screen mapping: screen.x = (u - 0.5) * 2 * aspect, screen.y = (v - 0.5) * 2.
  const u = 0.5 + 0.5 * (sx / virtualViewAspect);
  const v = 0.5 + 0.5 * sy;

  if (u < 0 || u > 1 || v < 0 || v > 1) {
    return new Vector2(-1, -1);
  }

  return new Vector2(u, v);
}

/**
 * Inverts the radial distortion polynomial
 *
 *     r_corr = r * (1 + k1 r^2 + k2 r^4 + k3 r^6 + k4 r^8)
 *
 * by bisection on r in [0, 0.5]. 0.5 is the largest undistorted radius the forward mapping
 * produces (theta = fov / 2).
 *
 * Assumes the polynomial is monotonically increasing on [0, 0.5]; otherwise bisection may
 * settle on the wrong root.
 *
 * @param r_corr - Distorted radius (non-negative).
 * @param coef - Distortion coefficients [k1, k2, k3, k4].
 * @returns The undistorted radius, or -1 when `r_corr` exceeds the distortion of the maximum
 *   radius (no solution in range) or the inputs are NaN.
 */
function invertDistortion(
  r_corr: number,
  coef: number[]
): number {
  // r_corr == 0 implies r == 0.
  if (r_corr < 1e-12) {
    return 0;
  }

  const maxR = 0.5;
  // Out-of-range (or NaN) input would otherwise silently converge to maxR; report failure so
  // callers checking `r < 0` can reject the point.
  if (!(r_corr <= distortedRadius(maxR, coef))) {
    return -1;
  }

  let left = 0;
  let right = maxR;
  for (let i = 0; i < 30; i++) {
    const mid = 0.5 * (left + right);
    if (distortedRadius(mid, coef) > r_corr) {
      right = mid;
    } else {
      left = mid;
    }
  }
  // Midpoint of the final bracket is the best estimate.
  return 0.5 * (left + right);
}

/**
 * Evaluates the 4-term radial distortion model at radius `r`:
 *    r * (1 + k1*r^2 + k2*r^4 + k3*r^6 + k4*r^8)
 * where [k1, k2, k3, k4] = coef. Missing coefficients yield NaN.
 */
function distortedRadius(r: number, coef: number[]): number {
  const [k1, k2, k3, k4] = coef;
  const rSq = r * r;
  const rPow4 = rSq * rSq;
  const rPow6 = rPow4 * rSq;
  const rPow8 = rPow4 * rPow4;
  const radialScale = 1.0 + k1 * rSq + k2 * rPow4 + k3 * rPow6 + k4 * rPow8;
  return r * radialScale;
}

/**
 * Inverts the roll-and-pitch rotation applied by `transformVirtualViewToFisheye`.
 *
 * The forward pass computes `Vector3.TransformCoordinates(ray, rollMatrix.multiply(pitchMatrix))`.
 * Babylon.js transforms row vectors (v * M), so that is `rotated = (ray * Roll) * Pitch`: the roll
 * matrix is applied first, and each matrix acts as the transpose of its written layout.
 *
 * The inverse is `ray = rotated * Pitch^T * Roll^T`. For a row vector, multiplying by M^T equals
 * the column-vector product M * v with M as written, giving:
 *   undo pitch: y' = cosP*y - sinP*z,  z' = sinP*y + cosP*z
 *   undo roll:  x' = cosR*x - sinR*y,  y' = sinR*x + cosR*y
 *
 * @param rotatedRay - Direction after the forward rotation.
 * @param pitch - Virtual view pitch in radians (same value as the forward pass).
 * @param roll - Virtual view roll in radians (same value as the forward pass).
 * @returns The direction before the forward rotation (a new Vector3).
 */
function applyInversePitchRoll(
  rotatedRay: Vector3,
  pitch: number,
  roll: number
): Vector3 {
  // Undo pitch first: it was applied last in the forward pass.
  const cosP = Math.cos(pitch);
  const sinP = Math.sin(pitch);
  const unpitchedX = rotatedRay.x; // rotation about X leaves x unchanged
  const unpitchedY = cosP * rotatedRay.y - sinP * rotatedRay.z;
  const unpitchedZ = sinP * rotatedRay.y + cosP * rotatedRay.z;

  // Then undo roll (applied first in the forward pass).
  const cosR = Math.cos(roll);
  const sinR = Math.sin(roll);
  const outX = cosR * unpitchedX - sinR * unpitchedY;
  const outY = sinR * unpitchedX + cosR * unpitchedY;

  return new Vector3(outX, outY, unpitchedZ); // rotation about Z leaves z unchanged
}

