/*
 * Decompiled with CFR 0.152.
 */
package org.ddogleg.clustering.gmm;

import org.ddogleg.clustering.AssignCluster;
import org.ddogleg.clustering.ComputeClusters;
import org.ddogleg.clustering.gmm.AssignGmm_F64;
import org.ddogleg.clustering.gmm.GaussianGmm_F64;
import org.ddogleg.clustering.gmm.GaussianLikelihoodManager;
import org.ddogleg.clustering.gmm.InitializeGmm_F64;
import org.ddogleg.struct.DogArray;
import org.ddogleg.struct.DogArray_F64;
import org.ddogleg.struct.LArrayAccessor;
import org.ejml.dense.row.CommonOps_DDRM;

/**
 * Clusters points by fitting a Gaussian Mixture Model (GMM) using Expectation-Maximization (EM).
 *
 * <p>Usage: call {@link #initialize(long)} once, then {@link #process(LArrayAccessor, int)}. Each EM
 * iteration runs an expectation step, which computes each point's normalized responsibility for every
 * Gaussian. It then runs a maximization step, which re-estimates each Gaussian's mean, covariance and
 * mixture weight from those responsibilities. Iteration stops after {@code maxIterations}, or once the
 * fractional decrease in the chi-square error falls within {@code [0, convergeTol]}.</p>
 */
public class ExpectationMaximizationGmm_F64 implements ComputeClusters<double[]> {
    /** Selects the initial state of each Gaussian before EM starts. */
    InitializeGmm_F64 selectInitial;
    /** The Gaussians being fit, one per cluster. */
    DogArray<GaussianGmm_F64> mixture;
    /** Per-point work storage: a copy of the point and its responsibility for each Gaussian. */
    DogArray<PointInfo> info;
    /** Maximum number of EM iterations performed by {@link #process}. */
    int maxIterations;
    /** Number of elements in each point. */
    int pointDimension;
    /** EM stops once the fractional decrease in chi-square error is in {@code [0, convergeTol]}. */
    double convergeTol;
    /** Evaluates each Gaussian's likelihood. Created by {@link #initialize(long)}. */
    GaussianLikelihoodManager likelihoodManager;
    /** Work space for the difference between a point and a Gaussian's mean. Sized in {@link #initialize}. */
    double[] dx = new double[1];
    /** Chi-square error computed by the most recent expectation step. */
    double errorChiSquare;
    /** If true, progress messages are printed to standard out. */
    boolean verbose;

    /**
     * Creates the clusterer.
     *
     * @param maxIterations maximum number of EM iterations
     * @param convergeTol tolerance on the fractional change in chi-square error used to detect convergence
     * @param pointDimension number of elements in each point
     * @param selectInitial selects the initial Gaussians
     */
    public ExpectationMaximizationGmm_F64(int maxIterations, double convergeTol, int pointDimension, InitializeGmm_F64 selectInitial) {
        this.maxIterations = maxIterations;
        this.convergeTol = convergeTol;
        this.selectInitial = selectInitial;
        this.pointDimension = pointDimension;
        this.info = new DogArray<PointInfo>(() -> new PointInfo(pointDimension));
        this.mixture = new DogArray<GaussianGmm_F64>(() -> new GaussianGmm_F64(pointDimension));
        System.err.println("WARNING:  GMM-EM is a work in progress!  Might not work in your situation");
    }

    /**
     * Initializes the seed selector, sizes internal work space and creates the likelihood manager.
     * Must be called before {@link #process(LArrayAccessor, int)}.
     *
     * @param randomSeed seed passed to the initial-seed selector
     */
    @Override
    public void initialize(long randomSeed) {
        selectInitial.init(pointDimension, randomSeed);
        if (dx.length < pointDimension) {
            dx = new double[pointDimension];
        }
        likelihoodManager = new GaussianLikelihoodManager(pointDimension, mixture.toList());
    }

    /**
     * Fits a mixture of {@code numCluster} Gaussians to the points.
     *
     * @param points the points being clustered
     * @param numCluster number of Gaussians in the mixture
     * @throws IllegalStateException if {@link #initialize(long)} has not been called
     */
    @Override
    public void process(LArrayAccessor<double[]> points, int numCluster) {
        if (likelihoodManager == null) {
            throw new IllegalStateException("initialize() must be called before process()");
        }

        // Copy the points into local storage and size the per-point responsibility arrays
        mixture.resize(numCluster);
        info.resize(points.size());
        for (int i = 0; i < points.size(); i++) {
            PointInfo p = info.get(i);
            points.getCopy(i, p.point);
            p.weights.resize(numCluster);
        }

        if (verbose) {
            System.out.println("GMM-EM: Selecting initial seeds");
        }
        selectInitial.selectSeeds(points, mixture.toList());
        likelihoodManager.precomputeAll();

        if (verbose) {
            System.out.println("GMM-EM: Entering main loop");
        }

        double errorBefore = Double.MAX_VALUE;
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            errorChiSquare = expectation();
            if (verbose) {
                System.out.println("GMM-EM: " + iteration + " errorChiSquare " + errorChiSquare);
            }

            // Converged when the error decreased by a small enough fraction. An increase in error
            // (negative change) is not treated as convergence.
            double fractionChange = 1.0 - errorChiSquare / errorBefore;
            if (fractionChange >= 0.0 && fractionChange <= convergeTol) {
                if (verbose) {
                    System.out.println("GMM-EM: CONVERGED");
                }
                break;
            }
            errorBefore = errorChiSquare;

            maximization();
            // The Gaussians changed, so refresh the likelihood manager's precomputed state
            likelihoodManager.precomputeAll();
        }

        // Discard per-point storage; the fitted model lives in 'mixture'
        info.reset();
    }

    /**
     * Expectation step. For each point, stores the likelihood under every Gaussian in
     * {@code PointInfo.weights} and normalizes them to sum to one when their total is positive.
     *
     * @return the sum, over all points, of the chi-square value of the Gaussian with the highest likelihood
     */
    protected double expectation() {
        double sumChiSq = 0.0;
        for (int i = 0; i < info.size; i++) {
            PointInfo p = info.get(i);

            double bestLikelihood = 0.0;
            // NOTE(review): if every likelihood for this point is 0 (e.g. underflow), bestChiSq stays
            // Double.MAX_VALUE. A few such points overflow sumChiSq to infinity, after which the relative-change
            // convergence test in process() can no longer succeed. Left as-is to preserve the error metric.
            double bestChiSq = Double.MAX_VALUE;
            double total = 0.0;
            for (int j = 0; j < mixture.size; j++) {
                GaussianLikelihoodManager.Likelihood g = likelihoodManager.getLikelihood(j);
                double likelihood = g.likelihood(p.point);
                p.weights.data[j] = likelihood;
                total += likelihood;
                if (likelihood > bestLikelihood) {
                    bestLikelihood = likelihood;
                    bestChiSq = g.getChisq();
                }
            }

            // Convert likelihoods into responsibilities. If all are zero, the point contributes nothing.
            if (total > 0.0) {
                for (int j = 0; j < mixture.size; j++) {
                    p.weights.data[j] /= total;
                }
            }
            sumChiSq += bestChiSq;
        }
        return sumChiSq;
    }

    /**
     * Maximization step. Re-estimates every Gaussian's mean, covariance and mixture weight from the
     * responsibilities computed by {@link #expectation()}.
     */
    protected void maximization() {
        // Discard the previous parameters
        for (int i = 0; i < mixture.size; i++) {
            mixture.get(i).zero();
        }

        // Accumulate responsibility-weighted sums of the points
        for (int i = 0; i < info.size; i++) {
            PointInfo p = info.get(i);
            for (int j = 0; j < mixture.size; j++) {
                mixture.get(j).addMean(p.point, p.weights.get(j));
            }
        }

        // Turn the weighted sums into means. Gaussians with no accumulated weight are skipped.
        for (int i = 0; i < mixture.size; i++) {
            GaussianGmm_F64 g = mixture.get(i);
            if (g.weight > 0.0) {
                CommonOps_DDRM.divide(g.mean, g.weight);
            }
        }

        // Accumulate responsibility-weighted outer products of (point - mean)
        for (int i = 0; i < info.size; i++) {
            PointInfo pp = info.get(i);
            double[] p = pp.point;
            for (int j = 0; j < mixture.size; j++) {
                GaussianGmm_F64 g = mixture.get(j);
                for (int k = 0; k < p.length; k++) {
                    dx[k] = p[k] - g.mean.data[k];
                }
                g.addCovariance(dx, pp.weights.get(j));
            }
        }

        // Normalize the covariances and sum the weights of the Gaussians that received any
        double totalMixtureWeight = 0.0;
        for (int i = 0; i < mixture.size; i++) {
            GaussianGmm_F64 g = mixture.get(i);
            if (g.weight > 0.0) {
                CommonOps_DDRM.divide(g.covariance, g.weight);
                totalMixtureWeight += g.weight;
            }
        }

        // Normalize so the mixture weights sum to one. If no Gaussian received any weight (no points, or every
        // likelihood was zero), skip the division; otherwise 0/0 would turn every weight into NaN.
        if (totalMixtureWeight > 0.0) {
            for (int i = 0; i < mixture.size; i++) {
                mixture.get(i).weight /= totalMixtureWeight;
            }
        }
    }

    /**
     * Returns an assignment object built from the current mixture.
     */
    @Override
    public AssignCluster<double[]> getAssignment() {
        return new AssignGmm_F64(mixture.toList());
    }

    /**
     * Returns the chi-square error computed by the most recent expectation step.
     */
    @Override
    public double getDistanceMeasure() {
        return errorChiSquare;
    }

    /**
     * Turns progress output on or off for this object and its initial-seed selector.
     */
    @Override
    public void setVerbose(boolean verbose) {
        selectInitial.setVerbose(verbose);
        this.verbose = verbose;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public ComputeClusters<double[]> newInstanceThread() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Per-point work storage: a copy of the point and its responsibility for each Gaussian.
     */
    public static class PointInfo {
        /** Copy of the point's values. */
        public double[] point;
        /** Responsibility of each Gaussian for this point, indexed by Gaussian. */
        public DogArray_F64 weights = new DogArray_F64();

        /**
         * @param pointDimension number of elements in the point
         */
        public PointInfo(int pointDimension) {
            this.point = new double[pointDimension];
        }
    }
}

