/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.forward.standard;

import deepboof.BaseTensor;
import deepboof.Function;
import deepboof.Tensor;
import deepboof.misc.TensorOps;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Skeleton {@link Function} implementation that handles bookkeeping shared by concrete functions.
 * It records the input shape, validates parameter/input/output shapes and records the
 * mini-batch size. The actual work is delegated to the subclass hooks {@code _initialize},
 * {@code _setParameters} and {@code _forward}.
 *
 * @param <T> tensor type consumed and produced by this function
 */
public abstract class BaseFunction<T extends Tensor>
implements Function<T> {
    /**
     * Copy of the shape passed to {@link #initialize(int...)}. Presumably this excludes the
     * mini-batch axis, since {@code forward} reads the batch size from axis 0 separately.
     * TODO confirm against TensorOps.checkShape.
     */
    protected int[] shapeInput = new int[0];
    /**
     * Expected shape of each parameter tensor. Cleared by {@link #initialize(int...)};
     * presumably repopulated by the subclass in {@code _initialize}.
     */
    protected List<int[]> shapeParameters = new ArrayList<int[]>();
    /**
     * Expected output shape, checked in {@link #forward}. Filled with -1 by
     * {@link #initialize(int...)}; presumably replaced by the subclass in {@code _initialize}.
     */
    protected int[] shapeOutput = new int[0];
    /** Defensive copy of the list last passed to {@link #setParameters(List)}; null until then. */
    protected List<T> parameters;
    /** Length of axis 0 of the most recent input passed to {@link #forward}. */
    protected int miniBatchSize;
    /**
     * Whether {@link #initialize(int...)} has completed. A null check on {@code shapeInput}
     * cannot detect a missing initialization because that field defaults to an empty array.
     */
    private boolean initialized;

    /**
     * Records the input shape, resets the parameter and output shapes, then calls
     * {@link #_initialize()} so the subclass can configure itself.
     *
     * @param shapeInput shape of the input tensor; it is copied, so later changes by the caller
     *                   have no effect
     */
    @Override
    public void initialize(int ... shapeInput) {
        // Treat this function as uninitialized until the subclass hook finishes successfully.
        this.initialized = false;
        this.shapeInput = shapeInput.clone();
        this.shapeParameters.clear();
        // Invalidate the previous output shape; the subclass is expected to define the new one.
        Arrays.fill(this.shapeOutput, -1);
        this._initialize();
        this.initialized = true;
    }

    /**
     * Subclass hook invoked by {@link #initialize(int...)} after {@code shapeInput} has been
     * set. Presumably it must fill in {@code shapeParameters} and {@code shapeOutput}.
     */
    public abstract void _initialize();

    /**
     * Validates {@code parameters} against {@link #getParameterShapes()}, stores a copy, and
     * passes the caller's list to {@link #_setParameters(List)}.
     *
     * @param parameters parameter tensors, in the order of {@code shapeParameters}
     */
    @Override
    public void setParameters(List<T> parameters) {
        TensorOps.checkShape("parameters", this.shapeParameters, parameters, false);
        this.parameters = new ArrayList<T>(parameters);
        this._setParameters(parameters);
    }

    /**
     * Subclass hook invoked by {@link #setParameters(List)} after the shapes have been validated.
     *
     * @param var1 the parameter list exactly as passed to {@code setParameters}
     */
    public abstract void _setParameters(List<T> var1);

    /**
     * Returns the stored copy of the parameters.
     *
     * @return the internal parameter list, or null if {@link #setParameters(List)} was never called
     */
    @Override
    public List<T> getParameters() {
        return this.parameters;
    }

    /**
     * Validates the input and output shapes, records the mini-batch size from axis 0 of
     * {@code input}, then calls {@link #_forward}.
     *
     * @param input  input tensor; axis 0 is read as the mini-batch size
     * @param output output tensor; its axis 0 must match the input's
     * @throws IllegalArgumentException if {@link #initialize(int...)} has not completed, or if
     *                                  the output's axis 0 differs from the input's
     */
    @Override
    public void forward(T input, T output) {
        if (!this.initialized || this.shapeInput == null) {
            throw new IllegalArgumentException("Must initialize first!");
        }
        // The trailing 'true' presumably tells checkShape to skip the mini-batch axis.
        // TODO confirm in TensorOps.
        TensorOps.checkShape("input", -1, this.shapeInput, ((BaseTensor)input).getShape(), true);
        TensorOps.checkShape("output", -1, this.shapeOutput, ((BaseTensor)output).getShape(), true);
        this.miniBatchSize = input.length(0);
        int outputBatch = output.length(0);
        if (outputBatch != this.miniBatchSize) {
            throw new IllegalArgumentException("Dimension 0 in the output is " + outputBatch + " and does not match input dimension 0 of " + this.miniBatchSize);
        }
        this._forward(input, output);
    }

    /**
     * Subclass hook that performs the computation once {@link #forward} has validated the
     * shapes and set {@code miniBatchSize}.
     *
     * @param var1 input tensor
     * @param var2 output tensor to write into
     */
    public abstract void _forward(T var1, T var2);

    /**
     * @return the internal list of expected parameter shapes (not a copy)
     */
    @Override
    public List<int[]> getParameterShapes() {
        return this.shapeParameters;
    }

    /**
     * @return the internal expected output shape array (not a copy)
     */
    @Override
    public int[] getOutputShape() {
        return this.shapeOutput;
    }
}

