import * as tf from '@tensorflow/tfjs';
import * as _CV from 'mirada/dist/src/types/opencv';
import * as ort from 'onnxruntime-web';
import { simd } from 'wasm-feature-detect';

// `cv` is not imported: it is provided as a global (presumably by loading the opencv.js script),
// so only its type is declared here.
declare let cv: typeof _CV;

// Safari doesn't support WebGL rendering to a canvas from a worker thread yet.
// The CPU backend is fast enough given how little work we hand to TensorFlow.js here,
// so we select it globally to avoid browser detection and passing the backend around.
// setBackend() is asynchronous: rather than leave its promise floating, report a failed switch.
tf.setBackend('cpu')
  .then((succeeded) => {
    if (!succeeded) {
      console.error('TensorFlow.js could not initialise the CPU backend');
    }
  })
  .catch((error: unknown) => {
    console.error('TensorFlow.js could not switch to the CPU backend', error);
  });

// A tensor written out as nested JavaScript arrays, from rank 0 (a bare number) up to rank 6.
// processModelOutput reads it as `tensor[0][row][col]`.
// NOTE(review): this was previously described as the raw ONNX session output, but
// runModelInference produces this value via TensorFlow.js `.array()` — confirm the producer.
export type OnnxTensor =
  | number
  | number[]
  | number[][]
  | number[][][]
  | number[][][][]
  | number[][][][][]
  | number[][][][][][];

/** Tag identifying which step of the scan pipeline a worker message requests. */
export enum ScanMessageType {
  /** Turn a raw image into the model's input tensor (see ScanDataPrepare). */
  Prepare,
  /** Turn the model's output into the final document image (see ScanDataProcess). */
  Process,
}

/** The user's full-resolution RGBA pixels together with the image dimensions. */
type OriginalImageMatrix = {
  /** Row-major pixel bytes, four per pixel. */
  data: Uint8Array;
  /** Width in pixels. */
  width: number;
  /** Height in pixels. */
  height: number;
};

/**
 * `Prepare` step: the caller sends the image to scan and gets back the model input tensor plus
 * the original pixels (the `Process` message carries an OriginalImageMatrix too).
 */
export type ScanDataPrepare = {
  message: {
    type: ScanMessageType.Prepare;
    /** The image to scan. */
    image: ImageData;
  };
  response: {
    original: OriginalImageMatrix;
    /** Flattened tensor values, with the shape needed to rebuild the tensor. */
    tensor: {
      data: Float32Array | Int32Array | Uint8Array;
      shape: number[];
    };
  };
};

/**
 * `Process` step: the caller sends the original pixels and the model's output and gets back the
 * processed document image.
 */
export type ScanDataProcess = {
  message: {
    type: ScanMessageType.Process;
    original: OriginalImageMatrix;
    /** Model output, read as `tensor[0][row][col]`. */
    tensor: OnnxTensor;
  };
  response: {
    imageData: ImageData;
  };
};

/** Either request/response pair; the two are told apart by `message.type`. */
export type ScanMessage = ScanDataPrepare | ScanDataProcess;

/**
 * Probes the current runtime for WebAssembly support.
 *
 * @returns `wasm` — whether a minimal WebAssembly module can be compiled and instantiated;
 *   `simd` — the result of wasm-feature-detect's SIMD check.
 */
export async function wasmSupport() {
  const canInstantiateWasm = (): boolean => {
    if (typeof WebAssembly !== 'object' || typeof WebAssembly.instantiate !== 'function') {
      return false;
    }
    try {
      // "\0asm" magic number followed by binary-format version 1: the smallest valid module.
      const header = Uint8Array.of(0x0, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00);
      const probe = new WebAssembly.Module(header);
      return (
        probe instanceof WebAssembly.Module &&
        new WebAssembly.Instance(probe) instanceof WebAssembly.Instance
      );
    } catch {
      // Any failure to compile or instantiate means WebAssembly is not usable here.
      return false;
    }
  };

  const wasm = canInstantiateWasm();
  const simdSupported = await simd();
  return { simd: simdSupported, wasm };
}
/**
 * Converts a user-supplied image into the input tensor for the segmentation model.
 *
 * The image is resized to 384×384 and converted from RGBA to 3-channel BGR. It is then reordered
 * from [height, width, channel] to [channel, height, width] with values scaled to [0, 1], giving a
 * tensor of shape [1, 3, 384, 384]. The full-resolution RGBA pixels are returned alongside it.
 *
 * All OpenCV Mats and TensorFlow.js tensors created here are freed before returning.
 *
 * @param rawImage - The image to scan.
 * @returns A copy of the original pixels (detached from the OpenCV heap) and the flattened tensor.
 */
export async function prepareImageData(rawImage: ImageData): Promise<ScanDataPrepare['response']> {
  const original = cv.matFromImageData(rawImage);
  const resized = new cv.Mat();
  try {
    cv.resize(original, resized, new cv.Size(384, 384), 0, 0, cv.INTER_LINEAR);
    // Drop the alpha channel and swap to BGR channel order for the model.
    cv.cvtColor(resized, resized, cv.COLOR_RGB2BGR);
    // Swapping rows/columns here, combined with TensorFlow's default transpose below (which
    // reverses every axis), turns [H, W, C] into [C, H, W] without mirroring the image.
    cv.transpose(resized, resized);

    // Mat.data is a view onto the WASM heap: it becomes invalid once the Mat is deleted (or the heap
    // grows), and structured-cloning it would copy the entire heap buffer. Take our own copy.
    const originalData = original.data.slice();
    const { width, height } = original.size();

    // TensorFlow.js does the reshaping/normalisation because the third-party ONNX model does not
    // include these operators. tidy() frees the intermediates; the result is disposed below.
    const tensor = tf.tidy(() =>
      tf.tensor(resized.data, [384, 384, 3]).transpose().div(255).expandDims(0)
    );
    try {
      return {
        original: { data: originalData, width, height },
        tensor: {
          data: await tensor.data(),
          shape: tensor.shape,
        },
      };
    } finally {
      tensor.dispose();
    }
  } finally {
    // opencv.js Mats are not garbage collected.
    original.delete();
    resized.delete();
  }
}

/**
 * Runs the segmentation model on a prepared input tensor and reduces its output to per-pixel labels.
 *
 * @param tensor - Model input. Its values are copied into a float32 ONNX tensor of shape
 *   [1, 3, 384, 384] (the shape produced by `prepareImageData`), so it is expected to hold exactly
 *   1×3×384×384 values.
 * @param session - ONNX Runtime session for the model. The input is fed under the name `image`,
 *   and the output is read from `out`.
 * @returns The arg-max over the model's two output channels, as nested arrays of shape
 *   [1, 384, 384]. Each entry is the index (0 or 1) of the higher-scoring channel for that pixel.
 */
export async function runModelInference(tensor: tf.Tensor, session: ort.InferenceSession) {
  const inputTensor = new ort.Tensor(
    'float32',
    new Float32Array(await tensor.data()),
    [1, 3, 384, 384]
  );

  const result = await session.run({ image: inputTensor });

  // ONNX Runtime has no ndarray helpers, so the arg-max over the class axis is done in
  // TensorFlow.js. tidy() frees the intermediate tensor; the returned one is disposed below.
  const labels = tf.tidy(() =>
    tf.tensor(result.out.data as Float32Array, [1, 2, 384, 384]).argMax(1)
  );
  try {
    return await labels.array();
  } finally {
    labels.dispose();
  }
}

/** A corner of the detected document, in original-image pixel coordinates. */
type DocumentCorner = { x: number; y: number };

/**
 * Paints the model's label map into a single-channel 8-bit mask buffer, offset from its top-left.
 *
 * Label 0 ("background") becomes black (0) and any other value becomes white (255). Rows missing
 * from `labels` are painted as background.
 *
 * @param target - Row-major mask pixels, one byte per pixel, `stride` bytes per row.
 * @param stride - Width of the mask in pixels.
 * @param labels - Model output, indexed as `labels[0][row][col]`.
 * @param size - Number of rows and columns to copy from `labels`.
 * @param offset - Mask row and column at which label (0, 0) is written.
 */
function paintLabelMask(
  target: Uint8Array,
  stride: number,
  labels: OnnxTensor,
  size: number,
  offset: number
): void {
  const plane = Array.isArray(labels) ? labels[0] : undefined;
  for (let r = 0; r < size; r++) {
    const row = Array.isArray(plane) ? plane[r] : undefined;
    const start = (offset + r) * stride + offset;
    for (let c = 0; c < size; c++) {
      const label = Array.isArray(row) ? row[c] : 0;
      target[start + c] = label === 0 ? 0 : 255;
    }
  }
}

/**
 * Orders four corners as [top-left, top-right, bottom-right, bottom-left].
 *
 * The two corners with the smallest y form the top edge and the other two the bottom edge. Each
 * pair is then ordered by x; equal x values still yield two distinct corners.
 */
function orderDocumentCorners(
  points: readonly DocumentCorner[]
): [DocumentCorner, DocumentCorner, DocumentCorner, DocumentCorner] {
  const [a, b, c, d] = [...points].sort((p, q) => p.y - q.y);
  if (!a || !b || !c || !d) {
    throw new Error(`Expected 4 corners, got ${points.length}`);
  }
  const [tl, tr]: [DocumentCorner, DocumentCorner] = a.x <= b.x ? [a, b] : [b, a];
  const [bl, br]: [DocumentCorner, DocumentCorner] = c.x <= d.x ? [c, d] : [d, c];
  return [tl, tr, br, bl];
}

/**
 * Turns the segmentation model's output into a flattened, black-and-white scan of the document.
 *
 * The pipeline runs in order:
 * 1. Paint the label map into a padded mask.
 * 2. Trace the mask's outline (Canny, dilate, findContours).
 * 3. Approximate the largest outline by a polygon.
 * 4. Map its four corners back to original-image coordinates.
 * 5. Perspective-warp the original onto an upright rectangle.
 * 6. Binarise the result with an adaptive threshold.
 *
 * @param data - The original RGBA image and the model's 384×384 label map (`tensor[0][row][col]`).
 * @returns The processed RGBA image, or `undefined` if any step fails. This includes the case
 *   where no four-cornered outline is found. Failures are logged via `printError`.
 */
export async function processModelOutput(
  data: ScanDataProcess['message']
): Promise<ImageData | undefined> {
  // opencv.js objects live on the WASM heap and are not garbage collected, so everything created
  // here is registered with `track` and freed in the `finally` block on success and failure alike.
  const allocated: { delete(): unknown }[] = [];
  const track = <T extends { delete(): unknown }>(obj: T): T => {
    allocated.push(obj);
    return obj;
  };

  try {
    const original = track(
      cv.matFromArray(data.original.height, data.original.width, cv.CV_8UC4, data.original.data)
    );

    // Side length of the model's square label map.
    const imageSize = 384;
    // The label map is drawn into the centre of a mask twice its size. Presumably this is so a
    // document touching the frame edge still produces a closed outline.
    const offset = imageSize / 2;
    const scaleX = original.size().width / imageSize;
    const scaleY = original.size().height / imageSize;

    // The mask stays single-channel 8-bit: cvtColor cannot expand a 1-channel Mat with
    // COLOR_RGB2RGBA (it requires 3 or 4 input channels), and Canny accepts 8-bit 1-channel input.
    const mask = track(cv.Mat.zeros(imageSize * 2, imageSize * 2, cv.CV_8U));
    paintLabelMask(mask.data, mask.size().width, data.tensor, imageSize, offset);

    // Canny marks the edges of the foreground; dilation thickens them before contour tracing.
    const edges = track(new cv.Mat());
    cv.Canny(mask, edges, 225, 255);
    const kernel = track(cv.getStructuringElement(cv.MORPH_ELLIPSE, new cv.Size(5, 5)));
    const dilated = track(new cv.Mat());
    cv.dilate(edges, dilated, kernel);

    const contours = track(new cv.MatVector());
    const hierarchy = track(new cv.Mat());
    cv.findContours(dilated, contours, hierarchy, cv.RETR_LIST, cv.CHAIN_APPROX_NONE);

    // The largest outline (first one on ties) is taken to be the document.
    let largestIndex = -1;
    let largestArea = -Infinity;
    // @ts-expect-error types are wrong
    const contourCount: number = contours.size();
    for (let i = 0; i < contourCount; i++) {
      // @ts-expect-error FIXME
      const contour = track(contours.get(i));
      const area = cv.contourArea(contour);
      if (area > largestArea) {
        largestArea = area;
        largestIndex = i;
      }
    }
    if (largestIndex < 0) {
      throw new Error('No document outline found in the model output');
    }
    // @ts-expect-error FIXME
    const shape = track(contours.get(largestIndex));

    // Approximate the outline by a coarse closed polygon; a document should reduce to 4 corners.
    const epsilon = 0.04 * cv.arcLength(shape, true);
    const approx = track(new cv.Mat());
    cv.approxPolyDP(shape, approx, epsilon, true);

    const cornerCount = approx.size().height;
    if (cornerCount !== 4) {
      throw new Error(`Expected 4 document corners, found ${cornerCount}`);
    }

    // The polygon is stored as interleaved int32 (x, y) pairs in mask coordinates. Undo the
    // centring offset and scale up to the original image, keeping sub-pixel precision.
    const coords = approx.data32S;
    const points: DocumentCorner[] = [];
    for (let i = 0; i < cornerCount; i++) {
      points.push({
        x: (coords[i * 2] - offset) * scaleX,
        y: (coords[i * 2 + 1] - offset) * scaleY,
      });
    }
    const [tl, tr, br, bl] = orderDocumentCorners(points);

    // Size the output rectangle from the longer of each pair of opposite edges.
    const outWidth = Math.round(
      Math.max(Math.hypot(br.x - bl.x, br.y - bl.y), Math.hypot(tr.x - tl.x, tr.y - tl.y))
    );
    const outHeight = Math.round(
      Math.max(Math.hypot(tr.x - br.x, tr.y - br.y), Math.hypot(tl.x - bl.x, tl.y - bl.y))
    );
    if (outWidth < 1 || outHeight < 1) {
      throw new Error('Detected document outline is degenerate');
    }

    // Map the detected corners (tl, tr, br, bl) onto the corners of an upright rectangle.
    const srcCoords = track(
      cv.matFromArray(4, 1, cv.CV_32FC2, [tl.x, tl.y, tr.x, tr.y, br.x, br.y, bl.x, bl.y])
    );
    const dstCoords = track(
      cv.matFromArray(4, 1, cv.CV_32FC2, [
        0,
        0,
        outWidth - 1,
        0,
        outWidth - 1,
        outHeight - 1,
        0,
        outHeight - 1,
      ])
    );
    const transform = track(cv.getPerspectiveTransform(srcCoords, dstCoords));

    const scan = track(new cv.Mat());
    cv.warpPerspective(
      original,
      scan,
      transform,
      new cv.Size(outWidth, outHeight),
      cv.INTER_LINEAR,
      cv.BORDER_CONSTANT,
      new cv.Scalar()
    );

    // Binarise with a Gaussian-weighted adaptive threshold (21px neighbourhood, constant 15), then
    // expand back to RGBA because ImageData holds four channels per pixel.
    cv.cvtColor(scan, scan, cv.COLOR_RGBA2GRAY);
    cv.adaptiveThreshold(
      scan,
      scan,
      255,
      cv.ADAPTIVE_THRESH_GAUSSIAN_C,
      cv.THRESH_BINARY,
      21,
      15
    );
    cv.cvtColor(scan, scan, cv.COLOR_GRAY2RGBA);

    // Uint8ClampedArray copies the pixels out of the WASM heap before the Mat is freed.
    return new ImageData(new Uint8ClampedArray(scan.data), scan.size().width, scan.size().height);
  } catch (e) {
    printError(e);
    return undefined;
  } finally {
    for (const obj of allocated) {
      obj.delete();
    }
  }
}

// Adapted from https://github.com/opencv-js/utils/blob/master/utils.js
/**
 * Logs an error to the console, translating opencv.js exception pointers into readable messages.
 *
 * A thrown value that is a number, or a string whose first space-separated token parses as a
 * number, is treated as an opencv.js exception pointer. When `cv` is loaded, such a pointer is
 * resolved with `cv.exceptionFromPtr` and logged as `Exception: <msg>`.
 *
 * `undefined` is logged as an empty string. Anything else is logged unchanged, as is a pointer
 * that `cv.exceptionFromPtr` fails to resolve, so the logger itself never throws.
 *
 * @param err - Whatever was caught.
 */
function printError(err: unknown): void {
  if (err === undefined) {
    console.error('');
    return;
  }

  let pointer = Number.NaN;
  if (typeof err === 'number') {
    pointer = err;
  } else if (typeof err === 'string') {
    pointer = Number(err.split(' ')[0]);
  }

  if (!Number.isNaN(pointer) && typeof cv !== 'undefined') {
    try {
      console.error('Exception: ' + cv.exceptionFromPtr(pointer).msg);
      return;
    } catch {
      // Not a resolvable exception pointer after all; fall through and log the raw value.
    }
  }

  console.error(err);
}
