/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.coref.neural;

import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

/**
 * Extracts word-embedding based feature vectors for documents and mentions.
 *
 * <p>Two embedding tables are consulted. Single-word lookups prefer {@code tunedWordEmbeddings}
 * and fall back to {@code staticWordEmbeddings}. Averaged features use {@code staticWordEmbeddings}
 * when present and {@code tunedWordEmbeddings} otherwise. All fields are final; the class keeps no
 * mutable state of its own.
 */
public class EmbeddingExtractor
implements Serializable {
    private static final long serialVersionUID = -663338564691488202L;

    /** Matches one ASCII digit; every digit is rewritten to '0' during normalization. */
    private static final Pattern DIGIT = Pattern.compile("\\d");

    /** When false, {@link #getDocumentEmbedding} returns an all-zero vector. */
    private final boolean conll;
    /** Optional (may be null) table used for averaged features and as a single-word fallback. */
    private final Embedding staticWordEmbeddings;
    /** Table consulted first for single words; dereferenced unconditionally, so must be non-null. */
    private final Embedding tunedWordEmbeddings;
    /** Word looked up for token positions that fall outside the sentence. */
    private final String naEmbedding;

    /**
     * @param conll                whether document embeddings are computed (otherwise zeros)
     * @param staticWordEmbeddings static embedding table, or {@code null} to use only the tuned one
     * @param tunedWordEmbeddings  tuned embedding table; must not be {@code null}
     * @param naEmbedding          word whose embedding stands in for out-of-sentence positions
     */
    public EmbeddingExtractor(boolean conll, Embedding staticWordEmbeddings, Embedding tunedWordEmbeddings, String naEmbedding) {
        this.conll = conll;
        this.staticWordEmbeddings = staticWordEmbeddings;
        this.tunedWordEmbeddings = tunedWordEmbeddings;
        this.naEmbedding = naEmbedding;
    }

    /**
     * Computes the document embedding. This is the average word embedding over every sentence that
     * contains at least one predicted mention; each sentence (keyed by {@code sentNum}) is counted
     * once.
     *
     * @param document document whose predicted mentions select the sentences to average
     * @return a column vector; all zeros when this extractor was not built for CoNLL data
     */
    public SimpleMatrix getDocumentEmbedding(Document document) {
        if (!this.conll) {
            // Size from the same table getAverageEmbedding uses, so this works when no static
            // table was supplied and both branches return vectors of the same dimension.
            return new SimpleMatrix(this.averagingEmbeddings().getEmbeddingSize(), 1);
        }
        List<CoreLabel> words = new ArrayList<>();
        Set<Integer> seenSentences = new HashSet<>();
        for (Mention m : document.predictedMentionsByID.values()) {
            // Set.add returns false for a sentence already collected.
            if (seenSentences.add(m.sentNum)) {
                words.addAll(m.sentenceWords);
            }
        }
        return this.getAverageEmbedding(words);
    }

    /**
     * Builds the compact mention feature vector. Components, in order:
     * words at start-2 and start-1, the first word, the head word, the last word, the words at end
     * and end+1, the head's dependency parent, and the average over at most the first 10 mention
     * tokens. Out-of-sentence positions use {@code naEmbedding}.
     *
     * @param m mention to encode
     * @return concatenated column vector
     */
    public SimpleMatrix getMentionEmbeddingsForFast(Mention m) {
        String depParent = EmbeddingExtractor.dependencyParentWord(m);
        List<CoreLabel> sentence = m.sentenceWords;
        return NeuralUtils.concatenate(
            this.getWordEmbedding(sentence, m.startIndex - 2),
            this.getWordEmbedding(sentence, m.startIndex - 1),
            this.getWordEmbedding(sentence, m.startIndex),
            this.getWordEmbedding(sentence, m.headIndex),
            this.getWordEmbedding(sentence, m.endIndex - 1),
            this.getWordEmbedding(sentence, m.endIndex),
            this.getWordEmbedding(sentence, m.endIndex + 1),
            // A null parent normalizes to "<missing>".
            this.getWordEmbedding(depParent),
            this.getAverageEmbedding(sentence.subList(m.startIndex, Math.min(m.endIndex, m.startIndex + 10))));
    }

    /**
     * Builds the full mention feature vector. Components, in order:
     * <ol>
     *   <li>average over the mention span [start, end) (clamped, see the 3-arg average)</li>
     *   <li>average over up to 5 tokens before the mention</li>
     *   <li>average over up to 5 tokens after the mention</li>
     *   <li>average over every sentence token except the last</li>
     *   <li>{@code docEmbedding}, unchanged</li>
     *   <li>head word, first word, last word</li>
     *   <li>words at start-1, end, start-2, end+1</li>
     *   <li>the head's dependency parent ({@code "<missing>"} when none)</li>
     * </ol>
     *
     * @param m            mention to encode
     * @param docEmbedding document-level vector; presumably from {@link #getDocumentEmbedding} —
     *                     confirm at call sites
     * @return concatenated column vector
     */
    public SimpleMatrix getMentionEmbeddings(Mention m, SimpleMatrix docEmbedding) {
        String depParent = EmbeddingExtractor.dependencyParentWord(m);
        List<CoreLabel> sentence = m.sentenceWords;
        return NeuralUtils.concatenate(
            this.getAverageEmbedding(sentence, m.startIndex, m.endIndex),
            this.getAverageEmbedding(sentence, m.startIndex - 5, m.startIndex),
            this.getAverageEmbedding(sentence, m.endIndex, m.endIndex + 5),
            this.getAverageEmbedding(sentence.subList(0, sentence.size() - 1)),
            docEmbedding,
            this.getWordEmbedding(sentence, m.headIndex),
            this.getWordEmbedding(sentence, m.startIndex),
            this.getWordEmbedding(sentence, m.endIndex - 1),
            this.getWordEmbedding(sentence, m.startIndex - 1),
            this.getWordEmbedding(sentence, m.endIndex),
            this.getWordEmbedding(sentence, m.startIndex - 2),
            this.getWordEmbedding(sentence, m.endIndex + 1),
            this.getWordEmbedding(depParent));
    }

    /**
     * Returns the word at the source of the first incoming enhanced-dependency edge into the
     * mention's head, or {@code null} when the head has no incoming edge.
     */
    private static String dependencyParentWord(Mention m) {
        Iterator<SemanticGraphEdge> incoming = m.enhancedDependency.incomingEdgeIterator(m.headIndexedWord);
        return incoming.hasNext() ? incoming.next().getSource().word() : null;
    }

    /** Table used for averaged features: the static table, or the tuned one when none was given. */
    private Embedding averagingEmbeddings() {
        return this.staticWordEmbeddings != null ? this.staticWordEmbeddings : this.tunedWordEmbeddings;
    }

    /**
     * Averages the embeddings of the normalized {@code words}.
     *
     * @return column vector of the averaging table's size; all zeros for an empty list
     */
    private SimpleMatrix getAverageEmbedding(List<CoreLabel> words) {
        Embedding embeddings = this.averagingEmbeddings();
        SimpleMatrix sum = new SimpleMatrix(embeddings.getEmbeddingSize(), 1);
        for (CoreLabel word : words) {
            sum = sum.plus(embeddings.get(EmbeddingExtractor.normalizeWord(word.word())));
        }
        // max(1, n) keeps an empty list from dividing by zero.
        return sum.divide((double) Math.max(1, words.size()));
    }

    /**
     * Averages {@code sentence[start, end)} after clamping both bounds into
     * {@code [0, sentence.size() - 1]}. Because the upper bound is clamped to {@code size - 1},
     * the sentence's final token is never included.
     * NOTE(review): that exclusion is presumably intentional (e.g. skipping sentence-final
     * punctuation) and baked into trained weights. Confirm before changing.
     */
    private SimpleMatrix getAverageEmbedding(List<CoreLabel> sentence, int start, int end) {
        int last = sentence.size() - 1;
        int from = Math.max(Math.min(start, last), 0);
        int to = Math.max(Math.min(end, last), 0);
        return this.getAverageEmbedding(sentence.subList(from, to));
    }

    /** Embeds token {@code i} of {@code sentence}, or {@code naEmbedding} when {@code i} is out of range. */
    private SimpleMatrix getWordEmbedding(List<CoreLabel> sentence, int i) {
        boolean inRange = i >= 0 && i < sentence.size();
        return this.getWordEmbedding(inRange ? sentence.get(i).word() : this.naEmbedding);
    }

    /**
     * Returns the embedding of {@code word} after {@link #normalizeWord normalization}.
     * The tuned table is used when it contains the word or when no static table exists;
     * otherwise the static table is used.
     *
     * @param word raw token text; {@code null} is looked up as {@code "<missing>"}
     */
    public SimpleMatrix getWordEmbedding(String word) {
        String normalized = EmbeddingExtractor.normalizeWord(word);
        if (this.staticWordEmbeddings == null || this.tunedWordEmbeddings.containsWord(normalized)) {
            return this.tunedWordEmbeddings.get(normalized);
        }
        return this.staticWordEmbeddings.get(normalized);
    }

    /**
     * Maps a raw token to its embedding-lookup key. The rules are:
     * <ul>
     *   <li>{@code null} becomes {@code "<missing>"}</li>
     *   <li>{@code "/."}, {@code "/?"} and the bracket escapes {@code -LRB-}/{@code -RRB-},
     *       {@code -LCB-}/{@code -RCB-} and {@code -LSB-}/{@code -RSB-} map to their literal
     *       characters and are returned directly</li>
     *   <li>{@code "''"} becomes a double quote</li>
     *   <li>a leading {@code '%'} is stripped from tokens longer than one character</li>
     *   <li>finally, every ASCII digit becomes {@code '0'} and the result is lowercased</li>
     * </ul>
     * Lowercasing uses {@link Locale#ROOT}, so keys do not depend on the JVM default locale.
     * Under a Turkish locale, for example, {@code 'I'} would otherwise lowercase to dotless
     * {@code 'ı'}.
     */
    private static String normalizeWord(String w) {
        if (w == null) {
            return "<missing>";
        }
        String token = w;
        switch (token) {
            case "/.":
                return ".";
            case "/?":
                return "?";
            case "-LRB-":
                return "(";
            case "-RRB-":
                return ")";
            case "-LCB-":
                return "{";
            case "-RCB-":
                return "}";
            case "-LSB-":
                return "[";
            case "-RSB-":
                return "]";
            case "''":
                token = "\"";
                break;
            default:
                if (token.startsWith("%") && token.length() > 1) {
                    token = token.substring(1);
                }
                break;
        }
        return DIGIT.matcher(token).replaceAll("0").toLowerCase(Locale.ROOT);
    }
}

