import env from "../../../env";
import { getImageDataFromMedia, tryOrLog, wrappedFetch } from "../../../utils";
import createTFLiteModule from "./createTFLiteModule";
import createTFLiteSIMDModule from "./createTFLiteSIMDModule";

// URL of the selfie-segmentation TFLite model. Judging by the file name it
// takes 256x256 input and stores f16 weights. It is presumably served from
// the app's public assets folder (env.PUBLIC_URL).
const modelPath = [
  `${env.PUBLIC_URL}/models/256x256`,
  "selfiesegmentation_mlkit-256x256-2021_01_19-v1215.f16.tflite",
].join("/");

/**
 * Default processing parameters.
 *
 * Every TfLiteManager starts from a shallow copy of this object, and init()
 * overrides are merged on top of that copy. The object is frozen so these
 * shared defaults cannot be mutated by accident.
 *
 * - processWidth / processHeight: size the media frame is rasterized to
 *   before inference. Also passed to the wasm `_exec` call.
 * - kernelSize: forwarded to wasm `_setKernelSize`.
 * - useSoftmax / usePadding: booleans, forwarded as 1/0 to `_setUseSoftmax`
 *   and `_setUsePadding`.
 * - threshold: forwarded to `_setThresholdWithoutSoftmax`. Presumably only
 *   used when useSoftmax is false — TODO confirm.
 * - interpolation: forwarded to `_setInterpolation`. The meaning of its
 *   values is defined by the wasm module — TODO confirm.
 */
const defaultModelParams = Object.freeze({
  processWidth: 640,
  processHeight: 480,
  kernelSize: 0.0,
  useSoftmax: true,
  usePadding: false,
  threshold: 0.1,
  interpolation: 1,
});

/**
 * Owns a TFLite WebAssembly module running the selfie-segmentation model and
 * turns media frames into one value per pixel.
 *
 * Call `await manager.init(params)` once, then `await manager.predict(media)`
 * for each frame. Until init() succeeds, predict() returns the last stored
 * result, which is an empty array initially.
 */
export class TfLiteManager {
  // wasm module instance; assigned only after a successful init().
  #wasm = null;
  // Most recent predict() output; returned again while not initialized.
  #predictions = [];
  // Active processing parameters: module defaults plus init() overrides.
  #params = { ...defaultModelParams };
  // True once both the wasm module and the model have been loaded.
  initialized = false;

  /**
   * Loads the runtime and the model, then merges `predictParams` into the
   * current parameters. Steps:
   * - Create the wasm module, trying the SIMD build first and falling back
   *   to the regular build.
   * - Copy the model file into wasm memory and load it.
   * - Merge `predictParams` into the current parameters.
   *
   * Overrides are merged on top of the parameters left by any previous
   * successful init(); they are not reset to the defaults first.
   * `initialized` stays false if any step throws.
   *
   * @param {object} [predictParams={}] - overrides for keys of defaultModelParams.
   * @returns {Promise<void>}
   * @throws {Error} "TfLite could not be initialized" when neither wasm build is created.
   * @throws {Error} "Model load failed" when the module's `_loadModel` returns non-zero.
   */
  async init(predictParams = {}) {
    this.initialized = false;
    console.log("Initializing TFLite...");

    // NOTE(review): tryOrLog(fn, true) is assumed to catch and log failures,
    // so that a failed SIMD build falls through to the regular build below.
    // Confirm against utils.
    let wasm = null;
    await tryOrLog(async () => {
      wasm = await createTFLiteSIMDModule();
    }, true);
    if (!wasm) {
      await tryOrLog(async () => {
        console.warn("Fallback to regular tflite");
        wasm = await createTFLiteModule();
      }, true);
    }
    if (!wasm) {
      throw new Error("TfLite could not be initialized");
    }

    console.log("Loading background tensorflow model...");
    const modelResponse = await wrappedFetch(modelPath);
    const model = await modelResponse.arrayBuffer();
    // Copy the model bytes into the region the wasm side reserved for them,
    // then tell it how many bytes to parse.
    const modelBufferOffset = wasm._getModelBufferMemoryOffset();
    wasm.HEAPU8.set(new Uint8Array(model), modelBufferOffset);
    const modelLoadedStatus = wasm._loadModel(model.byteLength);
    if (modelLoadedStatus !== 0) {
      throw new Error("Model load failed");
    }
    console.log("Tensorflow background model loaded");

    this.#params = { ...this.#params, ...predictParams };
    this.#wasm = wasm;
    this.initialized = true;
  }

  /**
   * Runs the model on one media frame.
   *
   * The frame is rasterized to processWidth x processHeight by
   * getImageDataFromMedia. It is then copied into the wasm input buffer and
   * processed by `_exec`. One byte per pixel is read back: byte 3 (the 4th
   * channel, presumably alpha) of each 4-byte pixel in the wasm output buffer.
   *
   * The returned array is also stored internally and handed out again while
   * not initialized, so callers should not mutate it.
   *
   * @param {*} media - any source getImageDataFromMedia accepts (TODO confirm types).
   * @returns {Promise<number[]>} processWidth * processHeight byte values
   *   (0-255), or the previous result (initially []) when not initialized.
   */
  async predict(media) {
    if (!this.initialized) {
      return this.#predictions;
    }

    const wasm = this.#wasm;
    const {
      processWidth,
      processHeight,
      kernelSize,
      useSoftmax,
      usePadding,
      threshold,
      interpolation,
    } = this.#params;

    wasm._setKernelSize(kernelSize);
    wasm._setUseSoftmax(useSoftmax ? 1 : 0);
    wasm._setUsePadding(usePadding ? 1 : 0);
    wasm._setThresholdWithoutSoftmax(threshold);
    wasm._setInterpolation(interpolation);

    const mediaImageData = getImageDataFromMedia(media, {
      width: processWidth,
      height: processHeight,
    });
    const inputImageBufferOffset = wasm._getInputImageBufferOffset();
    wasm.HEAPU8.set(mediaImageData.data, inputImageBufferOffset);
    wasm._exec(processWidth, processHeight);

    const outputLength = processWidth * processHeight;
    const outputImageBufferOffset = wasm._getOutputImageBufferOffset();
    // Look up the heap view once for the loop, but only after _exec. This is
    // presumably an Emscripten module; if it was built with memory growth,
    // its HEAPU8 view object is replaced whenever memory grows.
    const heap = wasm.HEAPU8;
    const predictions = [];
    for (let i = 0; i < outputLength; i++) {
      predictions[i] = heap[outputImageBufferOffset + i * 4 + 3];
    }
    this.#predictions = predictions;
    return predictions;
  }
}

// Default export of the same class that is also exported by name above, so
// both `import TfLiteManager from ...` and `import { TfLiteManager } from ...` work.
export default TfLiteManager;
