/*
 * Decompiled with CFR 0.152.
 */
package boofcv.deepboof;

import boofcv.alg.misc.GPixelMath;
import boofcv.deepboof.BaseImageClassifier;
import boofcv.deepboof.DataManipulationOps;
import boofcv.struct.image.GrayF32;
import boofcv.struct.image.Planar;
import deepboof.io.torch7.ConvertTorchToBoofForward;
import deepboof.io.torch7.ParseAsciiTorch7;
import deepboof.io.torch7.ParseBinaryTorch7;
import deepboof.io.torch7.SequenceAndParameters;
import deepboof.io.torch7.struct.TorchGeneric;
import deepboof.io.torch7.struct.TorchList;
import deepboof.io.torch7.struct.TorchNumber;
import deepboof.io.torch7.struct.TorchObject;
import deepboof.io.torch7.struct.TorchString;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F32;
import java.io.File;
import java.io.IOException;
import java.util.List;

/**
 * Image classifier backed by a Network-in-Network (NiN) model stored in Torch7 format.
 *
 * <p>{@link #loadModel(File)} reads the network and its per-band normalization statistics from
 * {@code nin_bn_final.t7} and the category labels from {@code synset.t7} in the given directory.
 * Before inference, {@link #preprocess(Planar)} reorders the base class's RGB image into BGR band
 * order, divides every pixel by 255 and normalizes each band with the model's mean / std.
 */
public class ImageClassifierNiNImageNet
extends BaseImageClassifier {
    /** Number of color bands the network consumes (B, G, R). */
    private static final int NUM_BANDS = 3;

    /** Per-band mean read from the model's {@code transform.mean}; index i is applied to band i of the BGR image. */
    float[] mean;
    /** Per-band standard deviation read from the model's {@code transform.std}; index i is applied to band i of the BGR image. */
    float[] stdev;

    /** Image width/height passed to the base class and used as the network's input size. */
    static final int imageCrop = 224;

    /**
     * BGR view of the base class's {@code imageRgb}. Its bands are re-pointed at {@code imageRgb}'s bands on
     * every call to {@link #preprocess(Planar)}, so the pixel data is shared rather than copied.
     */
    Planar<GrayF32> imageBgr = new Planar<GrayF32>(GrayF32.class, imageCrop, imageCrop, NUM_BANDS);

    public ImageClassifierNiNImageNet() {
        super(imageCrop);
    }

    /**
     * Loads the NiN network, its normalization statistics and the category labels.
     *
     * @param directory directory containing {@code nin_bn_final.t7} (binary Torch7 model) and
     *                  {@code synset.t7} (ASCII Torch7 list of category names)
     * @throws IOException if either file cannot be parsed, or if the model provides fewer than
     *                     {@value #NUM_BANDS} mean / std values
     */
    @Override
    public void loadModel(File directory) throws IOException {
        List<TorchObject> list = new ParseBinaryTorch7().parse(new File(directory, "nin_bn_final.t7"));
        TorchGeneric torchSequence = (TorchGeneric)((TorchGeneric)list.get(0)).get("model");
        TorchGeneric torchNorm = (TorchGeneric)torchSequence.get("transform");

        // Validate before touching the fields so a malformed model does not leave half-initialized state
        // that would only fail later, inside preprocess(), with an index-out-of-bounds error.
        float[] loadedMean = torchListToArray((TorchList)torchNorm.get("mean"));
        float[] loadedStdev = torchListToArray((TorchList)torchNorm.get("std"));
        if (loadedMean.length < NUM_BANDS || loadedStdev.length < NUM_BANDS) {
            throw new IOException("Expected at least " + NUM_BANDS + " normalization values, found mean="
                    + loadedMean.length + " std=" + loadedStdev.length);
        }
        this.mean = loadedMean;
        this.stdev = loadedStdev;

        SequenceAndParameters seqparam = (SequenceAndParameters)ConvertTorchToBoofForward.convert(torchSequence);
        this.network = seqparam.createForward(NUM_BANDS, imageCrop, imageCrop);
        this.tensorOutput = new Tensor_F32(TensorOps.WI(1, this.network.getOutputShape()));

        TorchList torchCategories = (TorchList)new ParseAsciiTorch7().parse(new File(directory, "synset.t7")).get(0);
        this.categories.clear();
        for (Object entry : torchCategories.list) {
            this.categories.add(((TorchString)entry).message);
        }
    }

    /**
     * Converts a Torch7 list whose elements are all {@link TorchNumber}s into a float array.
     *
     * @param torch list of numbers
     * @return array with one element per list entry, in list order
     */
    private float[] torchListToArray(TorchList torch) {
        float[] ret = new float[torch.list.size()];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = (float)((TorchNumber)torch.list.get(i)).value;
        }
        return ret;
    }

    /**
     * Prepares an input image for the NiN network.
     *
     * <p>After {@code super.preprocess} fills {@code imageRgb}, the bands are reordered to BGR, scaled by
     * 1/255 and normalized per band. NOTE: {@code imageBgr} shares its band images with {@code imageRgb},
     * so {@code imageRgb} is modified in place as well.
     *
     * @param image input image
     * @return the normalized BGR image (the reused {@code imageBgr} field)
     * @throws IllegalStateException if {@link #loadModel(File)} has not been called successfully
     */
    @Override
    protected Planar<GrayF32> preprocess(Planar<GrayF32> image) {
        if (this.mean == null || this.stdev == null) {
            throw new IllegalStateException("loadModel() must be called before preprocess()");
        }
        super.preprocess(image);

        // The model expects BGR band order; swap references rather than copying pixels.
        this.imageBgr.bands[0] = this.imageRgb.bands[2];
        this.imageBgr.bands[1] = this.imageRgb.bands[1];
        this.imageBgr.bands[2] = this.imageRgb.bands[0];

        // NOTE(review): assumes super.preprocess leaves pixels in [0, 255] — confirm against BaseImageClassifier.
        GPixelMath.divide(this.imageBgr, 255.0, this.imageBgr);
        for (int band = 0; band < NUM_BANDS; ++band) {
            DataManipulationOps.normalize(this.imageBgr.getBand(band), this.mean[band], this.stdev[band]);
        }
        return this.imageBgr;
    }
}

