/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFCliqueTree;
import edu.stanford.nlp.ie.crf.CRFLabel;
import edu.stanford.nlp.ie.crf.CliquePotentialFunction;
import edu.stanford.nlp.ie.crf.HasCliquePotentialFunction;
import edu.stanford.nlp.ie.crf.LinearCliquePotentialFunction;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffUpdateFunction;
import edu.stanford.nlp.optimization.HasFeatureGrouping;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Objective function for training a linear-chain CRF: the negative conditional log-likelihood of
 * the training labels, optionally plus a prior (regularization) term, together with its gradient,
 * in both full-batch and stochastic (mini-batch) forms.
 *
 * <p>The gradient of each weight is formed as expected feature counts under the current model
 * ({@code E}) minus empirical feature counts in the labeled data ({@code Ehat}). Weights are
 * exchanged with optimizers as a flat vector of length {@link #domainDimension()} and used
 * internally as a 2D array whose row {@code i} has one entry per label in
 * {@code labelIndices.get(map[i])} (see {@link #to2D(double[])} and {@link #to1D(double[][])}).
 */
public class CRFLogConditionalObjectiveFunction
extends AbstractStochasticCachingDiffUpdateFunction
implements HasCliquePotentialFunction,
HasFeatureGrouping {
    private static final Redwood.RedwoodChannels log = Redwood.channels(CRFLogConditionalObjectiveFunction.class);

    /** Prior type: no regularization term is added by {@link #applyPrior}. */
    public static final int NO_PRIOR = 0;
    /** Prior type: quadratic (Gaussian / L2) penalty {@code w^2 / (2 sigma^2)} per weight. */
    public static final int QUADRATIC_PRIOR = 1;
    /** Prior type: Huber penalty, quadratic for {@code |w| < epsilon} and linear outside. */
    public static final int HUBER_PRIOR = 2;
    /** Prior type: quartic penalty {@code w^4 / (2 sigma^4)} per weight. */
    public static final int QUARTIC_PRIOR = 3;
    /** Prior type: dropout; {@link #applyPrior} adds no term for it in this class. */
    public static final int DROPOUT_PRIOR = 4;
    public static final boolean DEBUG2 = false;
    public static final boolean DEBUG3 = false;
    public static final boolean TIMED = false;
    public static final boolean CONDENSE = true;
    /** When true, per-position probabilities and per-weight derivatives are logged. */
    public static boolean VERBOSE = false;

    /** One of the {@code *_PRIOR} constants. */
    protected final int prior;
    /** Scale parameter of the prior. */
    protected final double sigma;
    /** Width of the quadratic region of the Huber prior. */
    protected final double epsilon = 0.1;
    /** {@code labelIndices.get(j)} indexes the label assignments of cliques spanning {@code j + 1} positions. */
    protected final List<Index<CRFLabel>> labelIndices;
    protected final Index<String> classIndex;
    /** Empirical feature counts, shaped like the 2D weights. */
    protected final double[][] Ehat;
    /** Expected feature counts under the current weights, shaped like the 2D weights. */
    protected final double[][] E;
    /** Per-thread expected-count buffers; allocated lazily when {@code multiThreadGrad > 1}. */
    protected double[][][] parallelE;
    /** Per-thread empirical-count buffers; allocated lazily when {@code multiThreadGrad > 1}. */
    protected double[][][] parallelEhat;
    protected final int window;
    protected final int numClasses;
    /** {@code map[f]} is the clique-size index (into {@link #labelIndices}) of feature {@code f}. */
    protected final int[] map;
    /** {@code data[doc][position][clique]} lists the active feature indices. */
    protected int[][][][] data;
    /** Optional real values for node-clique features, parallel to {@link #data}; may be null. */
    protected double[][][][] featureVal;
    protected int[][] labels;
    protected final int domainDimension;
    protected int[][] weightIndices;
    protected final String backgroundSymbol;
    protected int[][] featureGrouping = null;
    protected static final double smallConst = 1.0E-6;
    protected Random rand = new Random(Integer.MAX_VALUE);
    /** Number of threads used for gradient computation; values {@code <= 1} mean single-threaded. */
    protected final int multiThreadGrad;
    protected double[][] weights;
    protected CliquePotentialFunction cliquePotentialFunc;
    private ThreadsafeProcessor<TaskPart, TaskResult> expectedThreadProcessor = new ExpectationThreadsafeProcessor();
    private ThreadsafeProcessor<TaskPart, TaskResult> expectedAndEmpiricalThreadProcessor = new ExpectationThreadsafeProcessorWithEmpirical();

    /** Returns a random starting point drawn from this function's own seeded generator. */
    @Override
    public double[] initial() {
        return this.initial(this.rand);
    }

    /**
     * Returns a random starting point.
     *
     * @param useRandomSeed if true use a freshly seeded {@link Random}; otherwise use the
     *     function's fixed-seed generator
     */
    public double[] initial(boolean useRandomSeed) {
        Random randToUse = useRandomSeed ? new Random() : this.rand;
        return this.initial(randToUse);
    }

    /**
     * Returns a starting point with each weight drawn uniformly from {@code [0, 1)} plus a small
     * positive offset.
     */
    public double[] initial(Random randGen) {
        double[] initial = new double[this.domainDimension()];
        for (int i = 0; i < initial.length; ++i) {
            // The offset keeps every starting weight strictly positive.
            initial[i] = randGen.nextDouble() + smallConst;
        }
        return initial;
    }

    /**
     * Maps a prior name (case-insensitive) to one of the {@code *_PRIOR} constants.
     *
     * <p>{@code null} selects {@link #QUADRATIC_PRIOR}. The names lasso, ridge, gaussian, ae-lasso,
     * sg-lasso and g-lasso map to {@link #NO_PRIOR}, so this class adds no prior term for them.
     * NOTE(review): presumably those are applied elsewhere (e.g. by the minimizer) — confirm.
     *
     * @throws IllegalArgumentException if the name is not recognized
     */
    public static int getPriorType(String priorTypeStr) {
        if (priorTypeStr == null || "QUADRATIC".equalsIgnoreCase(priorTypeStr)) {
            return QUADRATIC_PRIOR;
        }
        if ("HUBER".equalsIgnoreCase(priorTypeStr)) {
            return HUBER_PRIOR;
        }
        if ("QUARTIC".equalsIgnoreCase(priorTypeStr)) {
            return QUARTIC_PRIOR;
        }
        if ("DROPOUT".equalsIgnoreCase(priorTypeStr)) {
            return DROPOUT_PRIOR;
        }
        if ("NONE".equalsIgnoreCase(priorTypeStr)) {
            return NO_PRIOR;
        }
        if (priorTypeStr.equalsIgnoreCase("lasso") || priorTypeStr.equalsIgnoreCase("ridge") || priorTypeStr.equalsIgnoreCase("gaussian") || priorTypeStr.equalsIgnoreCase("ae-lasso") || priorTypeStr.equalsIgnoreCase("sg-lasso") || priorTypeStr.equalsIgnoreCase("g-lasso")) {
            return NO_PRIOR;
        }
        throw new IllegalArgumentException("Unknown prior type: " + priorTypeStr);
    }

    /** Same as the full constructor with {@code calcEmpirical = true}. */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol, double sigma, double[][][][] featureVal, int multiThreadGrad) {
        this(data, labels, window, classIndex, labelIndices, map, priorType, backgroundSymbol, sigma, featureVal, multiThreadGrad, true);
    }

    /**
     * Creates the objective.
     *
     * @param calcEmpirical if true, empirical counts over all documents are computed into
     *     {@code Ehat} up front; otherwise {@code Ehat} starts as all zeros
     */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol, double sigma, double[][][][] featureVal, int multiThreadGrad, boolean calcEmpirical) {
        this.window = window;
        this.classIndex = classIndex;
        this.numClasses = classIndex.size();
        this.labelIndices = labelIndices;
        this.map = map;
        this.data = data;
        this.featureVal = featureVal;
        this.labels = labels;
        this.prior = CRFLogConditionalObjectiveFunction.getPriorType(priorType);
        this.backgroundSymbol = backgroundSymbol;
        this.sigma = sigma;
        this.multiThreadGrad = multiThreadGrad;
        this.Ehat = this.empty2D();
        this.E = this.empty2D();
        this.weights = this.empty2D();
        if (calcEmpirical) {
            this.empiricalCounts(this.Ehat);
        }
        int myDomainDimension = 0;
        for (int dim : map) {
            myDomainDimension += labelIndices.get(dim).size();
        }
        this.domainDimension = myDomainDimension;
        log.info("Running gradient on " + multiThreadGrad + " threads");
    }

    /** Adds the empirical counts of every document into {@code eHat}. */
    protected void empiricalCounts(double[][] eHat) {
        for (int m = 0; m < this.data.length; ++m) {
            this.empiricalCountsForADoc(eHat, m);
        }
    }

    /**
     * Adds the empirical feature counts of one document into {@code eHat}.
     *
     * <p>For each position and each clique size {@code j + 1}, the gold labels of the last
     * {@code j + 1} positions select the label column, and each active feature's row is
     * incremented by its feature value (node cliques with {@code featureVal}) or by 1.0.
     */
    protected void empiricalCountsForADoc(double[][] eHat, int docIndex) {
        int[][][] docData = this.data[docIndex];
        int[] docLabels = this.labels[docIndex];
        double[][][] featureValArr = this.featureVal != null ? this.featureVal[docIndex] : null;
        // Sliding window of the most recent `window` labels, initially all background.
        int[] windowLabels = new int[this.window];
        Arrays.fill(windowLabels, this.classIndex.indexOf(this.backgroundSymbol));
        if (docLabels.length > docData.length) {
            // Extra leading labels seed the window; the remaining labels align with docData.
            // NOTE(review): this seeds the window with docLabels[0..window) and shifts once before
            // use, whereas documentLogProbability seeds `given` with docLabels[0..window-1); the two
            // differ by one position for the same prefix length — confirm which is intended.
            System.arraycopy(docLabels, 0, windowLabels, 0, windowLabels.length);
            docLabels = Arrays.copyOfRange(docLabels, docLabels.length - docData.length, docLabels.length);
        }
        for (int i = 0; i < docData.length; ++i) {
            // Slide the window left by one and append the gold label at position i.
            System.arraycopy(windowLabels, 1, windowLabels, 0, this.window - 1);
            windowLabels[this.window - 1] = docLabels[i];
            int[][] docData_i = docData[i];
            for (int j = 0; j < docData_i.length; ++j) {
                // Clique j covers the last j + 1 labels of the window.
                int[] cliqueLabel = Arrays.copyOfRange(windowLabels, this.window - 1 - j, this.window);
                int labelIndex = this.labelIndices.get(j).indexOf(new CRFLabel(cliqueLabel));
                int[] docData_ij = docData_i[j];
                // Only node (size-1) cliques carry real-valued features.
                double[] featureValArr_ij = j == 0 && featureValArr != null ? featureValArr[i][j] : null;
                for (int n = 0; n < docData_ij.length; ++n) {
                    eHat[docData_ij[n]][labelIndex] += featureValArr_ij != null ? featureValArr_ij[n] : 1.0;
                }
            }
        }
    }

    /** Loads {@code x} into the 2D weight buffer and returns a linear potential function over it. */
    @Override
    public CliquePotentialFunction getCliquePotentialFunction(double[] x) {
        this.to2D(x, this.weights);
        return new LinearCliquePotentialFunction(this.weights);
    }

    /**
     * Adds one document's empirical counts to {@code Ehat} and expected counts to {@code E}.
     *
     * @return the document's log-probability under the current potentials
     */
    protected double expectedAndEmpiricalCountsAndValueForADoc(double[][] E, double[][] Ehat, int docIndex) {
        this.empiricalCountsForADoc(Ehat, docIndex);
        return this.expectedCountsAndValueForADoc(E, docIndex);
    }

    /** Returns one document's log-probability without accumulating any counts. */
    public double valueForADoc(int docIndex) {
        return this.expectedCountsAndValueForADoc(null, docIndex, false, true);
    }

    /** Adds one document's expected counts to {@code E} and returns its log-probability. */
    protected double expectedCountsAndValueForADoc(double[][] E, int docIndex) {
        return this.expectedCountsAndValueForADoc(E, docIndex, true, true);
    }

    /** Adds one document's expected counts to {@code E}; returns 0.0. */
    protected double expectedCountsForADoc(double[][] E, int docIndex) {
        return this.expectedCountsAndValueForADoc(E, docIndex, true, false);
    }

    /**
     * Builds a calibrated clique tree for the document with the current potentials and optionally
     * computes its log-probability and/or accumulates its expected counts.
     *
     * @return the document log-probability if {@code doValueCalc}, else 0.0
     */
    protected double expectedCountsAndValueForADoc(double[][] E, int docIndex, boolean doExpectedCountCalc, boolean doValueCalc) {
        int[][][] docData = this.data[docIndex];
        double[][][] featureVal3DArr = this.featureVal != null ? this.featureVal[docIndex] : null;
        CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, this.labelIndices, this.numClasses, this.classIndex, this.backgroundSymbol, this.cliquePotentialFunc, featureVal3DArr);
        double prob = doValueCalc ? this.documentLogProbability(docData, docIndex, cliqueTree) : 0.0;
        if (doExpectedCountCalc) {
            this.documentExpectedCounts(E, docData, featureVal3DArr, cliqueTree);
        }
        return prob;
    }

    /**
     * Adds expected feature counts for one document into {@code E}: every label assignment of
     * every clique contributes its marginal probability (times the feature value for node cliques
     * with real-valued features) to each active feature.
     */
    protected void documentExpectedCounts(double[][] E, int[][][] docData, double[][][] featureVal3DArr, CRFCliqueTree<String> cliqueTree) {
        for (int i = 0; i < docData.length; ++i) {
            int[][] docData_i = docData[i];
            for (int j = 0; j < docData_i.length; ++j) {
                Index<CRFLabel> labelIndex = this.labelIndices.get(j);
                int[] docData_ij = docData_i[j];
                double[] featureValArr_ij = j == 0 && featureVal3DArr != null ? featureVal3DArr[i][j] : null;
                int liSize = labelIndex.size();
                for (int k = 0; k < liSize; ++k) {
                    int[] label = labelIndex.get(k).getLabel();
                    double p = cliqueTree.prob(i, label);
                    for (int n = 0; n < docData_ij.length; ++n) {
                        E[docData_ij[n]][k] += featureValArr_ij != null ? p * featureValArr_ij[n] : p;
                    }
                }
            }
        }
    }

    /**
     * Returns the log-probability of the document's gold labels: the start-position term plus the
     * conditional log-probability of each label given the previous {@code window - 1} labels.
     */
    private double documentLogProbability(int[][][] docData, int docIndex, CRFCliqueTree<String> cliqueTree) {
        int[] docLabels = this.labels[docIndex];
        int[] given = new int[this.window - 1];
        Arrays.fill(given, this.classIndex.indexOf(this.backgroundSymbol));
        if (docLabels.length > docData.length) {
            System.arraycopy(docLabels, 0, given, 0, given.length);
            docLabels = Arrays.copyOfRange(docLabels, docLabels.length - docData.length, docLabels.length);
        }
        double startPosLogProb = cliqueTree.logProbStartPos();
        if (VERBOSE || Double.isNaN(startPosLogProb)) {
            System.err.printf("P_-1(Background) = % 5.3f%n", startPosLogProb);
        }
        double prob = startPosLogProb;
        for (int i = 0; i < docData.length; ++i) {
            int label = docLabels[i];
            double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
            if (VERBOSE || Double.isNaN(p)) {
                log.info("P(" + label + "|" + ArrayMath.toString(given) + ")=" + p);
            }
            prob += p;
            // Shift the conditioning context and append the current label.
            System.arraycopy(given, 1, given, 0, given.length - 1);
            given[given.length - 1] = label;
        }
        return prob;
    }

    /** Installs 2D weights (stored by reference) and rebuilds the clique potential function. */
    public void setWeights(double[][] weights) {
        this.weights = weights;
        this.cliquePotentialFunc = new LinearCliquePotentialFunction(weights);
    }

    /** Accumulates expected counts over all documents into {@code E}; returns the total log-probability. */
    protected double regularGradientAndValue() {
        return this.multiThreadGradient(ArrayMath.range(0, this.data.length), false);
    }

    /**
     * Accumulates expected counts (and, if requested, empirical counts) for the given documents
     * into {@code E} (and {@code Ehat}), split across {@code multiThreadGrad} threads. The caller is
     * responsible for clearing {@code E}/{@code Ehat} beforehand if fresh counts are wanted.
     *
     * @return the summed log-probability of the documents
     */
    protected double multiThreadGradient(int[] docIDs, boolean calculateEmpirical) {
        ThreadsafeProcessor<TaskPart, TaskResult> processor = calculateEmpirical ? this.expectedAndEmpiricalThreadProcessor : this.expectedThreadProcessor;
        if (this.multiThreadGrad <= 1) {
            // Single-threaded: the processor accumulates directly into E (and Ehat).
            return processor.process(new TaskPart(0, 0, docIDs.length, docIDs)).objective;
        }
        if (this.parallelE == null) {
            this.parallelE = new double[this.multiThreadGrad][][];
            for (int t = 0; t < this.multiThreadGrad; ++t) {
                this.parallelE[t] = this.empty2D();
            }
        }
        if (calculateEmpirical && this.parallelEhat == null) {
            this.parallelEhat = new double[this.multiThreadGrad][][];
            for (int t = 0; t < this.multiThreadGrad; ++t) {
                this.parallelEhat[t] = this.empty2D();
            }
        }
        MulticoreWrapper<TaskPart, TaskResult> wrapper = new MulticoreWrapper<TaskPart, TaskResult>(this.multiThreadGrad, processor);
        int totalLen = docIDs.length;
        int partLen = (totalLen + this.multiThreadGrad - 1) / this.multiThreadGrad;
        for (int part = 0; part < this.multiThreadGrad; ++part) {
            int currIndex = part * partLen;
            int endIndex = Math.min(currIndex + partLen, totalLen);
            wrapper.put(new TaskPart(part, currIndex, endIndex, docIDs));
        }
        wrapper.join();
        double objective = 0.0;
        while (wrapper.peek()) {
            TaskResult result = wrapper.poll();
            objective += result.objective;
            // Merge this part's private buffers into the shared accumulators.
            CRFLogConditionalObjectiveFunction.combine2DArr(this.E, this.parallelE[result.id]);
            if (calculateEmpirical) {
                CRFLogConditionalObjectiveFunction.combine2DArr(this.Ehat, this.parallelEhat[result.id]);
            }
        }
        return objective;
    }

    /**
     * Computes the full-data negative log-likelihood plus prior at {@code x} into {@code value},
     * and its gradient ({@code E - Ehat} plus prior gradient) into {@code derivative}.
     *
     * @throws RuntimeException if the log-probability is NaN
     */
    @Override
    public void calculate(double[] x) {
        this.to2D(x, this.weights);
        this.setWeights(this.weights);
        CRFLogConditionalObjectiveFunction.clear2D(this.E);
        double prob = this.regularGradientAndValue();
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate() - this may well indicate numeric underflow due to overly long documents.");
        }
        this.value = -prob;
        if (VERBOSE) {
            log.info("value is " + Math.exp(-this.value));
        }
        int index = 0;
        for (int i = 0; i < this.E.length; ++i) {
            double[] E_i = this.E[i];
            double[] Ehat_i = this.Ehat[i];
            for (int j = 0; j < E_i.length; ++j) {
                this.derivative[index] = E_i[j] - Ehat_i[j];
                if (VERBOSE) {
                    log.info("deriv(" + i + "," + j + ") = " + E_i[j] + " - " + Ehat_i[j] + " = " + this.derivative[index]);
                }
                ++index;
            }
        }
        this.applyPrior(x, 1.0);
    }

    /** Returns the number of training documents. */
    @Override
    public int dataDimension() {
        return this.data.length;
    }

    /**
     * Mini-batch version of {@link #calculate}: expected counts come from {@code batch}, while the
     * full-data empirical counts in {@code Ehat} and the prior are scaled by
     * {@code batch.length / dataDimension()}.
     *
     * @throws RuntimeException if the log-probability is NaN
     */
    @Override
    public void calculateStochastic(double[] x, double[] v, int[] batch) {
        this.to2D(x, this.weights);
        this.setWeights(this.weights);
        double batchScale = (double)batch.length / (double)this.dataDimension();
        // Fresh expected counts for this batch only (otherwise they accumulate across calls).
        CRFLogConditionalObjectiveFunction.clear2D(this.E);
        double prob = this.multiThreadGradient(batch, false);
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()");
        }
        this.value = -prob;
        int index = 0;
        for (int i = 0; i < this.E.length; ++i) {
            double[] E_i = this.E[i];
            double[] Ehat_i = this.Ehat[i];
            for (int j = 0; j < E_i.length; ++j) {
                this.derivative[index] = E_i[j] - batchScale * Ehat_i[j];
                if (VERBOSE) {
                    log.info("deriv(" + i + "," + j + ") = " + E_i[j] + " - " + Ehat_i[j] + " = " + this.derivative[index]);
                }
                ++index;
            }
        }
        this.applyPrior(x, batchScale);
    }

    /**
     * Performs one stochastic update in place: scales {@code x} by {@code xScale} (beware: this
     * mutates {@code x}), computes expected and empirical counts on {@code batch}, then adds
     * {@code (Ehat - E) * gScale} to {@code x}. No prior term is applied.
     *
     * <p>{@code E} and {@code Ehat} are cleared first so they hold only this batch's counts; note
     * that this replaces any full-data empirical counts previously held in {@code Ehat}.
     *
     * @return the negative log-likelihood of the batch at the scaled weights
     * @throws RuntimeException if the log-probability is NaN
     */
    @Override
    public double calculateStochasticUpdate(double[] x, double xScale, int[] batch, double gScale) {
        ArrayMath.multiplyInPlace(x, xScale);
        this.to2D(x, this.weights);
        this.setWeights(this.weights);
        CRFLogConditionalObjectiveFunction.clear2D(this.E);
        CRFLogConditionalObjectiveFunction.clear2D(this.Ehat);
        double prob = this.multiThreadGradient(batch, true);
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()");
        }
        this.value = -prob;
        int index = 0;
        for (int i = 0; i < this.E.length; ++i) {
            double[] E_i = this.E[i];
            double[] Ehat_i = this.Ehat[i];
            for (int j = 0; j < E_i.length; ++j) {
                x[index++] += (Ehat_i[j] - E_i[j]) * gScale;
            }
        }
        return this.value;
    }

    /**
     * Computes the (prior-free) gradient {@code E - Ehat} on {@code batch} into {@code derivative}.
     *
     * <p>{@code E} and {@code Ehat} are cleared first so they hold only this batch's counts; note
     * that this replaces any full-data empirical counts previously held in {@code Ehat}.
     */
    @Override
    public void calculateStochasticGradient(double[] x, int[] batch) {
        if (this.derivative == null) {
            this.derivative = new double[this.domainDimension()];
        }
        this.to2D(x, this.weights);
        this.setWeights(this.weights);
        CRFLogConditionalObjectiveFunction.clear2D(this.E);
        CRFLogConditionalObjectiveFunction.clear2D(this.Ehat);
        this.multiThreadGradient(batch, true);
        int index = 0;
        for (int i = 0; i < this.E.length; ++i) {
            double[] Ei = this.E[i];
            double[] Ehati = this.Ehat[i];
            for (int j = 0; j < Ei.length; ++j) {
                this.derivative[index++] = Ei[j] - Ehati[j];
            }
        }
    }

    /**
     * Returns the negative log-likelihood of {@code batch} after scaling {@code x} by
     * {@code xScale} (beware: this mutates {@code x}). No counts are accumulated.
     *
     * @throws RuntimeException if the log-probability is NaN
     */
    @Override
    public double valueAt(double[] x, double xScale, int[] batch) {
        ArrayMath.multiplyInPlace(x, xScale);
        this.to2D(x, this.weights);
        this.setWeights(this.weights);
        double prob = 0.0;
        for (int ind : batch) {
            prob += this.valueForADoc(ind);
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()");
        }
        this.value = -prob;
        return this.value;
    }

    /**
     * Returns the configured feature grouping, or a single group containing every weight index
     * when none has been set.
     */
    @Override
    public int[][] getFeatureGrouping() {
        if (this.featureGrouping != null) {
            return this.featureGrouping;
        }
        return new int[][] { ArrayMath.range(this.domainDimension()) };
    }

    public void setFeatureGrouping(int[][] fg) {
        this.featureGrouping = fg;
    }

    /**
     * Adds the prior's value to {@code value} and its gradient to {@code derivative}, both scaled
     * by {@code batchScale}. {@link #NO_PRIOR} and {@link #DROPOUT_PRIOR} add nothing.
     */
    protected void applyPrior(double[] x, double batchScale) {
        switch (this.prior) {
            case QUADRATIC_PRIOR: {
                // value += bs * w^2 / (2 sigma^2); grad += bs * w / sigma^2
                double lambda = batchScale / (this.sigma * this.sigma);
                for (int i = 0; i < x.length; ++i) {
                    double w = x[i];
                    double wlambda = w * lambda;
                    this.value += w * wlambda * 0.5;
                    this.derivative[i] += wlambda;
                }
                break;
            }
            case HUBER_PRIOR: {
                // Quadratic inside |w| < epsilon, linear outside; continuous at |w| == epsilon.
                double batchScaleSigmaSq = batchScale / (this.sigma * this.sigma);
                for (int i = 0; i < x.length; ++i) {
                    double w = x[i];
                    double wabs = w < 0.0 ? -w : w;
                    if (wabs < this.epsilon) {
                        double weps = batchScaleSigmaSq * w / this.epsilon;
                        this.value += w * 0.5 * weps;
                        this.derivative[i] += weps;
                    } else {
                        this.value += batchScaleSigmaSq * (wabs - this.epsilon / 2.0);
                        this.derivative[i] += w < 0.0 ? -batchScaleSigmaSq : batchScaleSigmaSq;
                    }
                }
                break;
            }
            case QUARTIC_PRIOR: {
                // value += bs * w^4 / (2 sigma^4); grad += 2 * bs * w^3 / sigma^4
                double sigmasq = this.sigma * this.sigma;
                double batchScaleSigmaQu = batchScale / (sigmasq * sigmasq);
                double lambda = 0.5 * batchScaleSigmaQu;
                for (int i = 0; i < x.length; ++i) {
                    double w = x[i];
                    double ww = w * w;
                    this.value += ww * ww * lambda;
                    this.derivative[i] += 4.0 * lambda * ww * w;
                }
                break;
            }
            default:
                break;
        }
    }

    /**
     * Computes per-position conditional label distributions from a calibrated clique tree.
     *
     * <p>{@code first()[i][c][p]} is the row-normalized, exponentiated {@code cTree.logProb(i, {p, c})}
     * over {@code p}; {@code second()[i - 1][p][c]} is the same quantity normalized over {@code c}.
     * NOTE(review): assuming {@code labelPair = {l1, l2}} means labels at positions i-1 and i, these
     * are P(prev | curr) and P(next | curr) — confirm. The last position's "next" rows stay zero.
     */
    protected Pair<double[][][], double[][][]> getCondProbs(CRFCliqueTree<String> cTree, int[][][] docData) {
        double[][][] prevGivenCurr = new double[docData.length][this.numClasses][this.numClasses];
        double[][][] nextGivenCurr = new double[docData.length][this.numClasses][this.numClasses];
        for (int i = 0; i < docData.length; ++i) {
            double[][] prevGivenCurrI = prevGivenCurr[i];
            double[][] nextGivenCurrIm1 = i > 0 ? nextGivenCurr[i - 1] : null;
            int[] labelPair = new int[2];
            for (int l1 = 0; l1 < this.numClasses; ++l1) {
                labelPair[0] = l1;
                for (int l2 = 0; l2 < this.numClasses; ++l2) {
                    labelPair[1] = l2;
                    double prob = cTree.logProb(i, labelPair);
                    if (i > 0) {
                        nextGivenCurrIm1[l1][l2] = prob;
                    }
                    prevGivenCurrI[l2][l1] = prob;
                }
            }
            for (int j = 0; j < this.numClasses; ++j) {
                if (i > 0) {
                    double[] nextRow = nextGivenCurrIm1[j];
                    ArrayMath.logNormalize(nextRow);
                    ArrayMath.expInPlace(nextRow);
                }
                double[] prevRow = prevGivenCurrI[j];
                ArrayMath.logNormalize(prevRow);
                ArrayMath.expInPlace(prevRow);
            }
        }
        return new Pair<double[][][], double[][][]>(prevGivenCurr, nextGivenCurr);
    }

    /** Adds {@code scale * toBeCombined} element-wise into {@code combineInto}. */
    protected static void combine2DArr(double[][] combineInto, double[][] toBeCombined, double scale) {
        for (int i = 0; i < toBeCombined.length; ++i) {
            double[] row = combineInto[i];
            double[] srcRow = toBeCombined[i];
            for (int j = 0; j < srcRow.length; ++j) {
                row[j] += srcRow[j] * scale;
            }
        }
    }

    /** Adds {@code toBeCombined} element-wise into {@code combineInto}. */
    protected static void combine2DArr(double[][] combineInto, double[][] toBeCombined) {
        for (int i = 0; i < toBeCombined.length; ++i) {
            double[] row = combineInto[i];
            double[] srcRow = toBeCombined[i];
            for (int j = 0; j < srcRow.length; ++j) {
                row[j] += srcRow[j];
            }
        }
    }

    /** Adds each sparse row {@code toBeCombined[k]} into {@code combineInto[k]}. */
    protected static void combine2DArr(double[][] combineInto, Map<Integer, double[]> toBeCombined) {
        for (Map.Entry<Integer, double[]> entry : toBeCombined.entrySet()) {
            double[] row = combineInto[entry.getKey()];
            double[] source = entry.getValue();
            for (int i = 0; i < source.length; ++i) {
                row[i] += source[i];
            }
        }
    }

    /** Adds {@code scale} times each sparse row {@code toBeCombined[k]} into {@code combineInto[k]}. */
    protected static void combine2DArr(double[][] combineInto, Map<Integer, double[]> toBeCombined, double scale) {
        for (Map.Entry<Integer, double[]> entry : toBeCombined.entrySet()) {
            double[] row = combineInto[entry.getKey()];
            double[] source = entry.getValue();
            for (int i = 0; i < source.length; ++i) {
                row[i] += source[i] * scale;
            }
        }
    }

    /** Returns the total number of weights. */
    @Override
    public int domainDimension() {
        return this.domainDimension;
    }

    /** Splits a flat weight vector into freshly allocated rows sized by {@code labelIndices.get(map[i])}. */
    public static double[][] to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map) {
        double[][] newWeights = new double[map.length][];
        int index = 0;
        for (int i = 0; i < map.length; ++i) {
            int labelSize = labelIndices.get(map[i]).size();
            newWeights[i] = new double[labelSize];
            System.arraycopy(weights, index, newWeights[i], 0, labelSize);
            index += labelSize;
        }
        return newWeights;
    }

    /** Splits a flat weight vector into a new 2D array using this function's layout. */
    public double[][] to2D(double[] weights) {
        return CRFLogConditionalObjectiveFunction.to2D(weights, this.labelIndices, this.map);
    }

    /** Copies a flat weight vector into the pre-allocated rows of {@code newWeights}. */
    public static void to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map, double[][] newWeights) {
        int index = 0;
        for (int i = 0; i < map.length; ++i) {
            int labelSize = labelIndices.get(map[i]).size();
            System.arraycopy(weights, index, newWeights[i], 0, labelSize);
            index += labelSize;
        }
    }

    /** Copies a flat weight vector into {@code newWeights} using this function's layout. */
    public void to2D(double[] weights1D, double[][] newWeights) {
        CRFLogConditionalObjectiveFunction.to2D(weights1D, this.labelIndices, this.map, newWeights);
    }

    /** Zeroes every row of {@code arr2D} in place and returns it. */
    public static double[][] clear2D(double[][] arr2D) {
        for (double[] row : arr2D) {
            Arrays.fill(row, 0.0);
        }
        return arr2D;
    }

    /** Concatenates the rows of {@code weights} into {@code newWeights}. */
    public static void to1D(double[][] weights, double[] newWeights) {
        int index = 0;
        for (double[] weightVector : weights) {
            System.arraycopy(weightVector, 0, newWeights, index, weightVector.length);
            index += weightVector.length;
        }
    }

    /** Concatenates the rows of {@code weights} into a new array of length {@code domainDimension}. */
    public static double[] to1D(double[][] weights, int domainDimension) {
        double[] newWeights = new double[domainDimension];
        CRFLogConditionalObjectiveFunction.to1D(weights, newWeights);
        return newWeights;
    }

    /** Flattens 2D weights into a new vector of length {@link #domainDimension()}. */
    public double[] to1D(double[][] weights) {
        return CRFLogConditionalObjectiveFunction.to1D(weights, this.domainDimension());
    }

    /** Returns (and caches) the flat-vector index of every 2D weight position. */
    public int[][] getWeightIndices() {
        if (this.weightIndices == null) {
            this.weightIndices = new int[this.map.length][];
            int index = 0;
            for (int i = 0; i < this.map.length; ++i) {
                int labelSize = this.labelIndices.get(this.map[i]).size();
                int[] row = new int[labelSize];
                for (int j = 0; j < labelSize; ++j) {
                    row[j] = index++;
                }
                this.weightIndices[i] = row;
            }
        }
        return this.weightIndices;
    }

    /** Allocates a zeroed 2D array with this function's weight layout. */
    protected double[][] empty2D() {
        double[][] d = new double[this.map.length][];
        for (int i = 0; i < this.map.length; ++i) {
            d[i] = new double[this.labelIndices.get(this.map[i]).size()];
        }
        return d;
    }

    public int[][] getLabels() {
        return this.labels;
    }

    /**
     * Processes a slice of document ids, accumulating expected and empirical counts. When
     * single-threaded it writes straight into {@code E}/{@code Ehat}; otherwise into the cleared
     * per-part buffers {@code parallelE[id]}/{@code parallelEhat[id]}.
     */
    class ExpectationThreadsafeProcessorWithEmpirical
    implements ThreadsafeProcessor<TaskPart, TaskResult> {
        ExpectationThreadsafeProcessorWithEmpirical() {
        }

        @Override
        public TaskResult process(TaskPart part) {
            // Must match the single-thread condition used by multiThreadGradient.
            boolean singleThreaded = CRFLogConditionalObjectiveFunction.this.multiThreadGrad <= 1;
            double[][] partE = singleThreaded ? CRFLogConditionalObjectiveFunction.this.E : CRFLogConditionalObjectiveFunction.clear2D(CRFLogConditionalObjectiveFunction.this.parallelE[part.id]);
            double[][] partEhat = singleThreaded ? CRFLogConditionalObjectiveFunction.this.Ehat : CRFLogConditionalObjectiveFunction.clear2D(CRFLogConditionalObjectiveFunction.this.parallelEhat[part.id]);
            double probSum = 0.0;
            for (int i = part.begin; i < part.end; ++i) {
                probSum += CRFLogConditionalObjectiveFunction.this.expectedAndEmpiricalCountsAndValueForADoc(partE, partEhat, part.docIds[i]);
            }
            return new TaskResult(part.id, probSum);
        }

        @Override
        public ThreadsafeProcessor<TaskPart, TaskResult> newInstance() {
            return this;
        }
    }

    /**
     * Processes a slice of document ids, accumulating expected counts. When single-threaded it
     * writes straight into {@code E}; otherwise into the cleared buffer {@code parallelE[id]}.
     */
    class ExpectationThreadsafeProcessor
    implements ThreadsafeProcessor<TaskPart, TaskResult> {
        ExpectationThreadsafeProcessor() {
        }

        @Override
        public TaskResult process(TaskPart part) {
            // Must match the single-thread condition used by multiThreadGradient.
            boolean singleThreaded = CRFLogConditionalObjectiveFunction.this.multiThreadGrad <= 1;
            double[][] partE = singleThreaded ? CRFLogConditionalObjectiveFunction.this.E : CRFLogConditionalObjectiveFunction.clear2D(CRFLogConditionalObjectiveFunction.this.parallelE[part.id]);
            double probSum = 0.0;
            for (int i = part.begin; i < part.end; ++i) {
                probSum += CRFLogConditionalObjectiveFunction.this.expectedCountsAndValueForADoc(partE, part.docIds[i]);
            }
            return new TaskResult(part.id, probSum);
        }

        @Override
        public ThreadsafeProcessor<TaskPart, TaskResult> newInstance() {
            return this;
        }
    }

    /** Result of one work part: its id and summed log-probability. */
    private static class TaskResult {
        public final int id;
        public final double objective;

        public TaskResult(int id, double objective) {
            this.id = id;
            this.objective = objective;
        }
    }

    /** A work part: documents {@code docIds[begin..end)} handled under part/thread id {@code id}. */
    private static class TaskPart {
        public final int id;
        public final int begin;
        public final int end;
        public final int[] docIds;

        public TaskPart(int id, int begin, int end, int[] docIds) {
            this.id = id;
            this.begin = begin;
            this.end = end;
            this.docIds = docIds;
        }
    }
}

