/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.coref.neural;

import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.List;
import java.util.stream.Collectors;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

/**
 * Feed-forward scoring model for neural coreference.
 *
 * <p>Holds the learned matrices used to compute an anaphoricity score for a single mention
 * ({@link #getAnaphoricityScore}) and a pairwise score for an antecedent/anaphor pair
 * ({@link #getPairwiseScore}). It also holds the projection matrices that produce the
 * antecedent/anaphor embeddings and the word embeddings shipped with the model.
 *
 * <p>The layer stacks {@code anaphoricityModel} and {@code pairwiseModel} are stored as
 * alternating weight/bias matrices: {@code [W0, b0, W1, b1, ...]}.
 */
public class NeuralCorefModel implements Serializable {
    private static final long serialVersionUID = 2139427931784505653L;

    // These fields are intentionally not final: readObject replaces them after default
    // deserialization.
    private SimpleMatrix antecedentMatrix;
    private SimpleMatrix anaphorMatrix;
    private SimpleMatrix pairFeaturesMatrix;
    private SimpleMatrix pairwiseFirstLayerBias;
    /** Alternating weight/bias matrices for the anaphoricity network. */
    private List<SimpleMatrix> anaphoricityModel;
    /** Alternating weight/bias matrices for the pairwise network, after its first layer. */
    private List<SimpleMatrix> pairwiseModel;
    private Embedding wordEmbeddings;

    /**
     * Creates a model from its learned parameters. The arguments are stored as given; they are
     * not copied.
     *
     * @param antecedentMatrix projection applied by {@link #getAntecedentEmbedding}
     * @param anaphorMatrix projection applied by {@link #getAnaphorEmbedding}
     * @param pairFeaturesMatrix weights applied to the pair features in the pairwise first layer
     * @param pairwiseFirstLayerBias bias added in the pairwise first layer
     * @param anaphoricityModel alternating weight/bias matrices for anaphoricity scoring
     * @param pairwiseModel alternating weight/bias matrices for the remaining pairwise layers
     * @param wordEmbeddings word embeddings associated with this model
     */
    public NeuralCorefModel(SimpleMatrix antecedentMatrix, SimpleMatrix anaphorMatrix, SimpleMatrix pairFeaturesMatrix, SimpleMatrix pairwiseFirstLayerBias, List<SimpleMatrix> anaphoricityModel, List<SimpleMatrix> pairwiseModel, Embedding wordEmbeddings) {
        this.antecedentMatrix = antecedentMatrix;
        this.anaphorMatrix = anaphorMatrix;
        this.pairFeaturesMatrix = pairFeaturesMatrix;
        this.pairwiseFirstLayerBias = pairwiseFirstLayerBias;
        this.anaphoricityModel = anaphoricityModel;
        this.pairwiseModel = pairwiseModel;
        this.wordEmbeddings = wordEmbeddings;
    }

    /**
     * Performs default deserialization, then replaces every matrix with a fresh copy made by
     * {@code SimpleMatrix}'s copy constructor.
     *
     * <p>NOTE(review): the copy is presumably there to normalize matrices serialized by a
     * different EJML version into the current internal representation. Confirm this before
     * removing it.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.antecedentMatrix = new SimpleMatrix(this.antecedentMatrix);
        this.anaphorMatrix = new SimpleMatrix(this.anaphorMatrix);
        this.pairFeaturesMatrix = new SimpleMatrix(this.pairFeaturesMatrix);
        this.pairwiseFirstLayerBias = new SimpleMatrix(this.pairwiseFirstLayerBias);
        this.anaphoricityModel = copyAll(this.anaphoricityModel);
        this.pairwiseModel = copyAll(this.pairwiseModel);
    }

    /** Returns a new mutable list containing a copy of each matrix, in the same order. */
    private static List<SimpleMatrix> copyAll(List<SimpleMatrix> matrices) {
        return matrices.stream().map(SimpleMatrix::new).collect(Collectors.toList());
    }

    /**
     * Scores a mention on its own with the anaphoricity network.
     *
     * @param mentionEmbedding the mention's embedding
     * @param anaphoricityFeatures additional features, concatenated after the embedding
     * @return the sum of the entries of the network's final-layer output
     */
    public double getAnaphoricityScore(SimpleMatrix mentionEmbedding, SimpleMatrix anaphoricityFeatures) {
        return score(NeuralUtils.concatenate(mentionEmbedding, anaphoricityFeatures), this.anaphoricityModel);
    }

    /**
     * Scores an antecedent/anaphor pair with the pairwise network.
     *
     * <p>The first layer is {@code ReLU(antecedent + anaphor + pairFeaturesMatrix * pairFeatures
     * + bias)}. Its output is then passed through the remaining layers in {@code pairwiseModel}.
     * NOTE(review): the two embeddings are presumably the outputs of
     * {@link #getAntecedentEmbedding} and {@link #getAnaphorEmbedding}. Confirm this with the
     * callers.
     *
     * @param antecedentEmbedding projected embedding of the candidate antecedent
     * @param anaphorEmbedding projected embedding of the anaphor
     * @param pairFeatures pair-level feature vector
     * @return the sum of the entries of the network's final-layer output
     */
    public double getPairwiseScore(SimpleMatrix antecedentEmbedding, SimpleMatrix anaphorEmbedding, SimpleMatrix pairFeatures) {
        SimpleMatrix firstLayerInput = antecedentEmbedding
                .plus(anaphorEmbedding)
                .plus(this.pairFeaturesMatrix.mult(pairFeatures))
                .plus(this.pairwiseFirstLayerBias);
        SimpleMatrix firstLayerOutput = NeuralUtils.elementwiseApplyReLU(firstLayerInput);
        return score(firstLayerOutput, this.pairwiseModel);
    }

    /**
     * Runs {@code features} through a stack of affine layers.
     *
     * <p>{@code weights} holds alternating entries {@code [W0, b0, W1, b1, ...]}. Each step
     * computes {@code W * features + b}. ReLU is applied only when {@code W} has more than one
     * row, so single-row (scalar-output) layers stay linear.
     *
     * @param features input column for the first layer
     * @param weights alternating weight and bias matrices
     * @return the sum of the entries of the last layer's output, or of {@code features} itself
     *     when {@code weights} is empty
     * @throws IllegalArgumentException if {@code weights} has an odd number of entries
     */
    private static double score(SimpleMatrix features, List<SimpleMatrix> weights) {
        if (weights.size() % 2 != 0) {
            throw new IllegalArgumentException(
                    "Expected alternating weight/bias matrices, but got " + weights.size() + " entries");
        }
        for (int i = 0; i < weights.size(); i += 2) {
            SimpleMatrix weight = weights.get(i);
            SimpleMatrix bias = weights.get(i + 1);
            features = weight.mult(features).plus(bias);
            if (weight.numRows() > 1) {
                features = NeuralUtils.elementwiseApplyReLU(features);
            }
        }
        return features.elementSum();
    }

    /**
     * Projects a mention embedding into anaphor space.
     *
     * @param mentionEmbedding the mention's embedding
     * @return {@code anaphorMatrix * mentionEmbedding}
     */
    public SimpleMatrix getAnaphorEmbedding(SimpleMatrix mentionEmbedding) {
        return this.anaphorMatrix.mult(mentionEmbedding);
    }

    /**
     * Projects a mention embedding into antecedent space.
     *
     * @param mentionEmbedding the mention's embedding
     * @return {@code antecedentMatrix * mentionEmbedding}
     */
    public SimpleMatrix getAntecedentEmbedding(SimpleMatrix mentionEmbedding) {
        return this.antecedentMatrix.mult(mentionEmbedding);
    }

    /**
     * Returns the word embeddings stored with this model.
     *
     * @return the word embeddings (the stored instance, not a copy)
     */
    public Embedding getWordEmbeddings() {
        return this.wordEmbeddings;
    }
}

