import convnetjs from 'convnetjs'

// --- Tunables ---------------------------------------------------------------
// Edge length, in pixels, of both square working canvases.
const DRAW_SIZE = 100
// Inner-loop count used by `update`: random pixel samples trained per pass.
// (The identifier's spelling is kept because other code refers to it.)
const BATCHES_PER_INTERATION = 100
// `draw` only renders when `counter` is a multiple of this value.
const MOD_SKIP_DRAW = 10

// --- Images -----------------------------------------------------------------
// Source picture; `loadImage` assigns its `src`, which starts a load.
const originalImage = new Image()
// Exported image element. No visible code assigns its `src`.
export const networkImage = new Image()
// Data URL of the latest network rendering; overwritten by `draw`.
export let networkImageData = ''

// --- Mutable state (populated once the source image has loaded) -------------
let imageContext: CanvasRenderingContext2D | null
let networkContext: CanvasRenderingContext2D | null
// Pixel data grabbed from the source canvas; `update` reads its `.data`.
let imageData: any
// Iteration counter consulted by `draw` and reset to 0 by `loadImage`.
let counter = 0
// Handle of the running training timer, so `loadImage` can cancel it.
let interval: ReturnType<typeof setInterval>

// Number of neurons in every hidden fully-connected layer.
const HIDDEN_NEURONS = 20

// Layer stack handed to the network: an (x, y) input, two hidden ReLU
// layers, and a three-value regression output.
const layerDefinitions = [
  // Input volume is 1x1x2: the two values are the x and y coordinates.
  { type: 'input', out_sx: 1, out_sy: 1, out_depth: 2 },
  { type: 'fc', num_neurons: HIDDEN_NEURONS, activation: 'relu' },
  { type: 'fc', num_neurons: HIDDEN_NEURONS, activation: 'relu' },
  // Five further identical hidden layers are currently disabled:
  // { type: 'fc', num_neurons: 20, activation: 'relu' },
  // { type: 'fc', num_neurons: 20, activation: 'relu' },
  // { type: 'fc', num_neurons: 20, activation: 'relu' },
  // { type: 'fc', num_neurons: 20, activation: 'relu' },
  // { type: 'fc', num_neurons: 20, activation: 'relu' },
  // Regression head with three outputs.
  { type: 'regression', num_neurons: 3 },
]

// Build the network from `layerDefinitions`.
const network = new convnetjs.Net()
network.makeLayers(layerDefinitions)

// Trainer hyper-parameters. Learning rate and momentum are randomised once,
// when this module is evaluated: learning_rate lands in [0.005, 0.045) and
// momentum in [0.5, 0.9).
const trainerOptions = {
  learning_rate: 0.005 + 0.04 * Math.random(),
  momentum: 0.4 * Math.random() + 0.5,
  batch_size: 5,
  l2_decay: 0.0,
}
const trainer = new convnetjs.SGDTrainer(network, trainerOptions)

// Working canvases: one holds the scaled source image, the other receives
// the network's rendering. They are created here but not attached to the
// document by any visible code.
const imageCanvas = document.createElement('canvas')
const networkCanvas = document.createElement('canvas')

/**
 * Runs one round of training.
 *
 * Performs `trainer.batch_size * BATCHES_PER_INTERATION` training steps. Each
 * step picks a random pixel of the source image, feeds its centred and
 * normalised (x, y) position to the network, and trains against that pixel's
 * red, green and blue values scaled by 1/255.
 */
const update = () => {
  const pixels = imageData.data
  // One input volume is reused for every sample; only its two weights change.
  const sample = new convnetjs.Vol(1, 1, 2)
  const half = DRAW_SIZE / 2

  const passes = trainer.batch_size
  for (let pass = 0; pass < passes; pass++) {
    for (let n = 0; n < BATCHES_PER_INTERATION; n++) {
      const px = convnetjs.randi(0, DRAW_SIZE)
      const py = convnetjs.randi(0, DRAW_SIZE)

      // Four array entries per pixel, laid out row by row.
      const offset = (py * DRAW_SIZE + px) * 4
      const target = [
        pixels[offset] / 255.0,
        pixels[offset + 1] / 255.0,
        pixels[offset + 2] / 255.0,
      ]

      sample.w[0] = (px - half) / DRAW_SIZE
      sample.w[1] = (py - half) / DRAW_SIZE
      trainer.train(sample, target)
    }
  }
}

/**
 * Renders the network's current output into the network canvas and refreshes
 * the exported `networkImageData` data URL.
 *
 * Does nothing when the network canvas has no 2D context, or when `counter`
 * is not a multiple of `MOD_SKIP_DRAW`.
 */
const draw = () => {
  if (!networkContext) return
  if (counter % MOD_SKIP_DRAW !== 0) return

  const frame = networkContext.getImageData(0, 0, DRAW_SIZE, DRAW_SIZE)
  const pixels = frame.data
  // One input volume is reused for every pixel; only its two weights change.
  const point = new convnetjs.Vol(1, 1, 2)
  const half = DRAW_SIZE / 2

  for (let col = 0; col < DRAW_SIZE; col++) {
    point.w[0] = (col - half) / DRAW_SIZE
    for (let row = 0; row < DRAW_SIZE; row++) {
      point.w[1] = (row - half) / DRAW_SIZE

      const offset = (row * DRAW_SIZE + col) * 4
      const rgb = network.forward(point)
      // `~~` truncates the scaled outputs to integers before they are stored.
      pixels[offset] = ~~(255 * rgb.w[0])
      pixels[offset + 1] = ~~(255 * rgb.w[1])
      pixels[offset + 2] = ~~(255 * rgb.w[2])
      pixels[offset + 3] = 255 // fully opaque
    }
  }

  networkContext.putImageData(frame, 0, 0)
  networkImageData = networkCanvas.toDataURL()
  // networkImage.src = networkImageData
}

/**
 * One step of the training loop: train, optionally render, then advance the
 * iteration counter.
 *
 * `draw` skips rendering unless `counter` is a multiple of `MOD_SKIP_DRAW`,
 * so the counter must advance on every tick for that throttle to work.
 * Incrementing after `draw` means the first tick after a reset (counter 0)
 * still renders.
 */
const tick = () => {
  update()
  draw()
  counter += 1
}

// Once the source image has loaded: size both canvases, capture the image's
// pixels as the training data, and start the training timer.
originalImage.onload = () => {
  for (const canvas of [imageCanvas, networkCanvas]) {
    canvas.width = DRAW_SIZE
    canvas.height = DRAW_SIZE
  }

  const sourceContext = imageCanvas.getContext('2d')
  imageContext = sourceContext
  networkContext = networkCanvas.getContext('2d')
  if (!sourceContext) {
    throw new Error('No context')
  }

  // Scale the picture to the working size, then keep its pixels as the dataset.
  sourceContext.drawImage(originalImage, 0, 0, DRAW_SIZE, DRAW_SIZE)
  imageData = sourceContext.getImageData(0, 0, DRAW_SIZE, DRAW_SIZE)

  // Begin training: request a tick every millisecond.
  interval = setInterval(tick, 1)
}

/**
 * Switches training to a new source image.
 *
 * Resets the iteration counter, cancels the running training timer, and
 * points the source image at `url`; the image's `onload` handler takes over
 * once the picture has loaded.
 *
 * @param url - Location of the image to load.
 */
export const loadImage = (url: string): void => {
  counter = 0
  clearInterval(interval)
  originalImage.src = url
}
