import npyjs from 'npyjs';
import { env, InferenceSession, Tensor } from 'onnxruntime-web';
import ort from 'onnxruntime-web';

import { CanvasUtils } from '@/apps/mockup/classes/CanvasUtils';
import { SolidColor } from '@/color/classes/SolidColor';

// Location override for the SIMD WebAssembly binary used by onnxruntime-web.
// Only the 'ort-wasm-simd.wasm' entry is overridden here.
const ORT_WASM_SIMD_URL = 'https://cdn-om.cdnpk.net/ia/sam/ort-wasm-simd-v2.wasm?v1=true';

env.wasm.wasmPaths = { 'ort-wasm-simd.wasm': ORT_WASM_SIMD_URL };

/**
 * Scaling information for an image that is fed to the SAM model.
 *
 * Produced by `SAMModel.handleImageScale` from the source canvas.
 */
export interface modelScaleProps {
	/** Factor that maps the longest canvas side to the model's expected side length (1024 / max(height, width)). */
	samScale: number;
	/** Height of the source canvas, in canvas pixels. */
	height: number;
	/** Width of the source canvas, in canvas pixels. */
	width: number;
}

/**
 * Singleton class for handling SAM model operations.
 *
 * Usage order, as required by `runModel`:
 *  1. `initModel()` loads the ONNX decoder session.
 *  2. `loadNpyTensor(url)` loads the image embedding tensor.
 *  3. `loadModelScale(canvas)` records the source canvas size and scale.
 *  4. `runModel(point)` runs inference and draws the result onto `maskCanvas`.
 */
export class SAMModel {
	private static instance: SAMModel;
	private model?: InferenceSession;
	private tensor?: Tensor;
	private modelScale?: modelScaleProps;

	// Canvas that receives the most recently computed mask. Created eagerly, so
	// constructing this class requires a DOM (`document`) to be available.
	private mask: HTMLCanvasElement = document.createElement('canvas');

	/**
	 * Gets the singleton instance of the SAMModel class, creating it on first use.
	 * @returns The singleton instance of SAMModel.
	 */
	public static getInstance(): SAMModel {
		if (!SAMModel.instance) {
			SAMModel.instance = new SAMModel();
		}
		return SAMModel.instance;
	}

	/**
	 * Initializes the ONNX inference session from the remote decoder file.
	 *
	 * Best-effort: a load failure is logged and not rethrown, which leaves the
	 * model unset. A later `runModel` call then fails with
	 * 'Model or Tensor not initialized'.
	 */
	public initModel = async () => {
		try {
			const modelUrl = 'https://cdn-om.cdnpk.net/ia/sam/sam_vit_b_decoder.onnx.gz';
			this.model = await InferenceSession.create(modelUrl);
		} catch (error) {
			console.error('Failed to initialize SAM model', error);
		}
	};

	/**
	 * Builds a mask image from raw model output and stores it in the mask canvas.
	 * @param payload The input data for the mask.
	 * @param payload.input The raw output data of the model.
	 * @param payload.width The width of the mask
	 * @param payload.height The height of the mask
	 */
	private buildMask(payload: { input: any; width: number; height: number }) {
		this.mask = CanvasUtils.fromImageData(CanvasUtils.arrayToImageData(payload), this.mask);
	}

	/**
	 * Calculates the scale for the input canvas so that its longest side maps
	 * to the side length the model expects (1024).
	 * @param canvas The HTMLCanvasElement holding the source image.
	 * @returns The original canvas size together with the computed scale.
	 */
	private handleImageScale = (canvas: HTMLCanvasElement): modelScaleProps => {
		const LONG_SIDE_LENGTH = 1024;
		const w = canvas.width;
		const h = canvas.height;
		const samScale = LONG_SIDE_LENGTH / Math.max(h, w);
		return { height: h, width: w, samScale };
	};

	/**
	 * Prepares the feeds for the decoder from the image embedding and click prompts.
	 * @param tensor The image embedding tensor.
	 * @param clicks An array of click points, in source-canvas coordinates.
	 * @param modelScale The scale properties used for both the click scaling and the original image size.
	 * @returns The named input tensors for inference, or `undefined` when no clicks are given.
	 */
	private getModelData(tensor: Tensor, clicks: { x: number; y: number }[], modelScale: modelScaleProps) {
		// Without click prompts there is nothing to feed the model.
		if (!clicks) return undefined;

		const n = clicks.length;

		// No box input is used, so a single padding point with label -1 and
		// coordinates (0.0, 0.0) is concatenated: allocate room for (n + 1) points.
		const pointCoords = new Float32Array(2 * (n + 1));
		const pointLabels = new Float32Array(n + 1);

		// Add clicks, scaled to the coordinate space SAM expects.
		clicks.forEach((click, i) => {
			pointCoords[2 * i] = click.x * modelScale.samScale;
			pointCoords[2 * i + 1] = click.y * modelScale.samScale;
			// Per the original note: Meta's implementation reads clicks[i].clickType
			// here, but never changes the value from 1.
			pointLabels[i] = 1;
		});

		// The extra padding point at (0, 0) with label -1.
		pointCoords[2 * n] = 0.0;
		pointCoords[2 * n + 1] = 0.0;
		pointLabels[n] = -1.0;

		const pointCoordsTensor = new Tensor('float32', pointCoords, [1, n + 1, 2]);
		const pointLabelsTensor = new Tensor('float32', pointLabels, [1, n + 1]);

		// Use the scale passed in, so the size and the click scaling always come
		// from the same source.
		const imageSizeTensor = new Tensor('float32', [modelScale.height, modelScale.width]);

		// There is no previous mask, so default to an empty tensor
		const maskInput = new Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]);
		// There is no previous mask, so default to 0
		const hasMaskInput = new Tensor('float32', [0]);

		return {
			image_embeddings: tensor,
			point_coords: pointCoordsTensor,
			point_labels: pointLabelsTensor,
			orig_im_size: imageSizeTensor,
			mask_input: maskInput,
			has_mask_input: hasMaskInput,
		};
	}

	/**
	 * Loads the tensor from a .npy file and stores it for later inference.
	 * @param urlFile The URL of the .npy file to load.
	 * @returns The SAMModel instance for chaining.
	 */
	public async loadNpyTensor(urlFile: string): Promise<SAMModel> {
		const npLoader = new npyjs();
		const npArray = await npLoader.load(urlFile);
		this.tensor = new ort.Tensor('float32', npArray.data, npArray.shape);
		return this;
	}

	/**
	 * Computes and stores the scale properties for the given source canvas.
	 * @param canvas The canvas holding the image the clicks refer to.
	 * @returns The SAMModel instance for chaining.
	 */
	public async loadModelScale(canvas: HTMLCanvasElement): Promise<this> {
		this.modelScale = this.handleImageScale(canvas);
		return this;
	}

	/**
	 * Runs the model with the given input point, producing a mask on `maskCanvas`.
	 * @param point The input point for the model, in source-canvas coordinates.
	 * @throws Error when the model, the tensor or the model scale has not been
	 * loaded, or when the model feeds cannot be built.
	 */
	public async runModel(point: { x: number; y: number }): Promise<void> {
		const init = performance.now();
		if (!this.tensor || !this.model) throw new Error('Model or Tensor not initialized');
		// Fail explicitly instead of with a TypeError inside getModelData.
		if (!this.modelScale) throw new Error('Model scale not initialized');

		const feeds = this.getModelData(this.tensor, [point], this.modelScale);
		if (!feeds) throw new Error('Error extracting Model Data');

		const results = await this.model.run(feeds);
		const output = results[this.model.outputNames[0]];
		// NOTE(review): if the output is laid out as [batch, channel, rows, cols],
		// dims[2] is the row count (height) and dims[3] the column count (width).
		// Confirm that CanvasUtils.arrayToImageData expects this argument order.
		this.buildMask({
			input: output.data,
			width: output.dims[2],
			height: output.dims[3],
		});
		const end = performance.now();
		console.log('Time to run model: ', end - init + 'ms');
	}

	/**
	 * Returns the inference session.
	 *
	 * The value is `undefined` at runtime until `initModel` has completed
	 * successfully, despite the declared return type.
	 * @returns The InferenceSession model.
	 */
	public getModel(): InferenceSession {
		return this.model as InferenceSession;
	}

	/** Whether a tensor has been loaded via `loadNpyTensor`. */
	get hasTensor(): boolean {
		return this.tensor !== undefined;
	}

	/** The canvas holding the most recently built mask. */
	get maskCanvas() {
		return this.mask;
	}
}

/**
 * Shared SAMModel singleton, created eagerly when this module is first imported.
 * Constructing SAMModel calls `document.createElement`, so importing this
 * module requires a DOM to be available.
 */
export const SAMModelInstance = SAMModel.getInstance();
