/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.forward.standard;

import deepboof.forward.FunctionLinear;
import deepboof.impl.forward.standard.BaseFunction;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * Fully connected (linear) layer operating on {@link Tensor_F64}. Each output element is a bias
 * term plus the dot product of the flattened input sample with one row of the weight tensor.
 *
 * <p>Parameters, in the order they are declared in {@link #_initialize()} and consumed by
 * {@link #_setParameters(List)}:
 * <ol>
 *   <li>weight, shape {@code (M, D)}</li>
 *   <li>bias, shape {@code (M)}</li>
 * </ol>
 * where {@code M} is the number of outputs and {@code D} is the length of one input sample.
 */
public class FunctionLinear_F64 extends BaseFunction<Tensor_F64>
        implements FunctionLinear<Tensor_F64> {
    /** Number of elements in a single input sample; computed from the input shape in {@link #_initialize()}. */
    protected int D;
    /** Number of outputs produced per input sample; fixed by the constructor. */
    protected int M;
    /** Weight parameter, declared with shape {@code (M, D)}; assigned in {@link #_setParameters(List)}. */
    Tensor_F64 weight;
    /** Bias parameter, declared with shape {@code (M)}; assigned in {@link #_setParameters(List)}. */
    Tensor_F64 bias;

    /**
     * Creates a linear layer.
     *
     * @param numberOfOutputs number of outputs ({@code M}) produced for each input sample
     */
    public FunctionLinear_F64(int numberOfOutputs) {
        this.M = numberOfOutputs;
    }

    /**
     * Applies the layer to a mini-batch using the currently assigned weight and bias.
     *
     * @param input  input tensor
     * @param output tensor that receives the results
     */
    @Override
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        FunctionLinear_F64.forwards(input, output, this.weight, this.bias, this.miniBatchSize, this.D, this.M);
    }

    /**
     * Computes a linear transform independently for each sample of a mini-batch.
     *
     * <p>For sample {@code s} and output element {@code o}, the value written is the sum over
     * {@code i} in {@code [0, D)} of {@code input.d[s*D + i] * weight.d[o*D + i]}, plus
     * {@code bias.d[o]}. It is stored at {@code output.d[s*M + o]}. Every index is additionally
     * offset by the corresponding tensor's {@code startIndex}. All four tensors are therefore read
     * as contiguous row-major blocks.
     *
     * @param input         input data; {@code D} consecutive values per sample
     * @param output        destination; {@code M} consecutive values per sample are overwritten
     * @param weight        weights laid out as {@code M} rows of {@code D} values
     * @param bias          {@code M} bias values
     * @param miniBatchSize number of samples to process
     * @param D             number of input values per sample
     * @param M             number of output values per sample
     */
    public static void forwards(Tensor_F64 input, Tensor_F64 output, Tensor_F64 weight, Tensor_F64 bias, int miniBatchSize, int D, int M) {
        for (int stack = 0; stack < miniBatchSize; ++stack) {
            int indexStartIn = stack * D + input.startIndex;
            for (int outputElement = 0; outputElement < M; ++outputElement) {
                int indexW = outputElement * D + weight.startIndex;
                double b = bias.d[outputElement + bias.startIndex];
                int indexIn = indexStartIn;
                int end = indexIn + D;
                // Accumulate the dot product from zero, then add the bias, preserving the
                // original floating-point evaluation order.
                double sum = 0.0;
                while (indexIn < end) {
                    sum += input.d[indexIn++] * weight.d[indexW++];
                }
                int indexOut = stack * M + outputElement + output.startIndex;
                output.d[indexOut] = sum + b;
            }
        }
    }

    /**
     * Derives {@code D} from the input shape and declares the parameter and output shapes.
     *
     * @throws IllegalArgumentException if the input shape has no dimensions
     */
    @Override
    public void _initialize() {
        if (this.shapeInput.length < 1) {
            throw new IllegalArgumentException("Input tensor shape must have a dimension of at least 1");
        }
        // NOTE(review): shapeInput is presumably the per-sample shape (without the mini-batch
        // dimension), since forwards() strides the input by D per sample. Confirm in BaseFunction.
        this.D = TensorOps.tensorLength(this.shapeInput);
        this.shapeParameters.add(new int[]{this.M, this.D});
        this.shapeParameters.add(new int[]{this.M});
        this.shapeOutput = new int[]{this.M};
    }

    /**
     * Stores the parameters declared in {@link #_initialize()}.
     *
     * @param parameters list whose element 0 is the weight and element 1 is the bias
     */
    @Override
    public void _setParameters(List<Tensor_F64> parameters) {
        this.weight = parameters.get(0);
        this.bias = parameters.get(1);
    }

    /**
     * Returns the number of outputs produced per input sample.
     *
     * @return {@code M}, the value passed to the constructor
     */
    @Override
    public int getNumberOfOutputs() {
        // Previously returned D (the input length), which contradicts the constructor argument
        // and the declared output shape {M}.
        return this.M;
    }

    /**
     * Returns the tensor type this implementation operates on.
     *
     * @return {@code Tensor_F64.class}
     */
    @Override
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }
}

