/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.realtransform.inverse;

import net.imglib2.RealLocalizable;
import net.imglib2.RealPositionable;
import net.imglib2.realtransform.AffineTransform;
import net.imglib2.realtransform.RealTransform;
import net.imglib2.realtransform.inverse.DifferentiableRealTransform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Numerically inverts a {@link DifferentiableRealTransform} by iterative descent.
 * <p>
 * Given a target point {@code y}, {@link #inverseTol} repeatedly moves an estimate {@code x}
 * along the direction produced by {@code xfm.directionToward(...)}. The step length comes from
 * a backtracking line search. Iteration stops when the Euclidean residual
 * {@code |y - xfm(x)|} drops below the tolerance or the iteration budget is exhausted.
 * <p>
 * All per-call state (estimate, residual, scratch buffers) lives in instance fields. A single
 * instance is therefore not safe for concurrent use; obtain independent instances via
 * {@link #copy()}.
 */
public class InverseRealTransformGradientDescent
implements RealTransform {
    /** Dimensionality of both the source and the target space. */
    int ndims;
    /** Jacobian workspace; within this class it is only modified by {@link #regularizeJacobian()}. */
    AffineTransform jacobian;
    /** Allocated in the constructor but not read within this class. */
    double[] directionalDeriv;
    /**
     * Multiplies the sufficient-decrease term in {@link #armijoCondition(double, double)}.
     * It is initialized to 0 and not updated anywhere in this class.
     */
    double descentDirectionMag;
    /** Current search direction, filled by {@code xfm.directionToward}. */
    double[] dir;
    /** Residual {@code target - xfm(estimate)} computed by the last {@link #updateError()}. */
    double[] errorV;
    /** Current estimate of the pre-image of {@link #target}. */
    double[] estimate;
    /** {@code xfm} applied to {@link #estimate}. */
    double[] estimateXfm;
    /** Point whose pre-image is sought; always an internal buffer of length {@code ndims}. */
    double[] target;
    /** Stored and copied, but not read within this class. */
    boolean fixZ = false;
    /** Euclidean residual of the current estimate. */
    double error = 9999.0;
    /** Stored via {@link #setStepSize(double)}, but not read within this class. */
    double stepSz = 1.0;
    /** Factor by which the line search shrinks the step on each failed Armijo test. */
    double beta = 0.5;
    /** Default residual tolerance used by {@link #apply(double[], double[])}. */
    double tolerance = 0.5;
    /** Armijo sufficient-decrease constant. */
    double c = 1.0E-4;
    /** Default iteration budget used by {@link #apply(double[], double[])}. */
    int maxIters = 100;
    /** Stored and copied, but not read within this class. */
    double jacobianEstimateStep = 1.0;
    /** Amount added to each diagonal entry of {@link #jacobian} by {@link #regularizeJacobian()}. */
    double jacobianRegularizationEps = 0.1;
    /** Maximum number of step shrinkings per line search. */
    int stepSizeMaxTries = 1000;
    /** Upper clamp for the step returned by {@link #backtrackingLineSearch(double)}. */
    double maxStepSize = Double.MAX_VALUE;
    /** Lower clamp for the step returned by {@link #backtrackingLineSearch(double)}. */
    double minStepSize = 1.0E-9;
    private DifferentiableRealTransform xfm;
    protected static Logger logger = LoggerFactory.getLogger(InverseRealTransformGradientDescent.class);
    /** Scratch buffers for the {@code float[]} and {@code RealLocalizable} apply overloads. */
    private double[] srcd;
    private double[] tgtd;
    /** Scratch buffers for the trial point of the line search and its image. */
    private double[] x_ap;
    private double[] phix_ap;

    /**
     * Creates an inverter for {@code xfm}.
     *
     * @param ndims number of dimensions of source and target space
     * @param xfm the forward transform to invert
     */
    public InverseRealTransformGradientDescent(int ndims, DifferentiableRealTransform xfm) {
        this.ndims = ndims;
        this.xfm = xfm;
        this.dir = new double[ndims];
        this.errorV = new double[ndims];
        this.directionalDeriv = new double[ndims];
        this.descentDirectionMag = 0.0;
        this.jacobian = new AffineTransform(ndims);
        this.target = new double[ndims];
        this.estimate = new double[ndims];
        this.estimateXfm = new double[ndims];
        this.srcd = new double[ndims];
        this.tgtd = new double[ndims];
        this.x_ap = new double[ndims];
        this.phix_ap = new double[ndims];
    }

    /** Sets the line-search shrink factor applied after each failed Armijo test. */
    public void setBeta(double beta) {
        this.beta = beta;
    }

    /** Sets the Armijo sufficient-decrease constant. */
    public void setC(double c) {
        this.c = c;
    }

    /** Sets the residual tolerance used by {@link #apply(double[], double[])}. */
    public void setTolerance(double tol) {
        this.tolerance = tol;
    }

    /** Sets the iteration budget used by {@link #apply(double[], double[])}. */
    public void setMaxIters(int maxIters) {
        this.maxIters = maxIters;
    }

    /** Stores the fixZ flag (not read within this class). */
    public void setFixZ(boolean fixZ) {
        this.fixZ = fixZ;
    }

    /** Stores a step size (not read within this class). */
    public void setStepSize(double stepSize) {
        this.stepSz = stepSize;
    }

    /** Sets the lower clamp for steps returned by {@link #backtrackingLineSearch(double)}. */
    public void setMinStep(double minStep) {
        this.minStepSize = minStep;
    }

    /** Sets the upper clamp for steps returned by {@link #backtrackingLineSearch(double)}. */
    public void setMaxStep(double maxStep) {
        this.maxStepSize = maxStep;
    }

    /** Stores the Jacobian estimation step (not read within this class). */
    public void setJacobianEstimateStep(double jacStep) {
        this.jacobianEstimateStep = jacStep;
    }

    /** Sets the value {@link #regularizeJacobian()} adds to each diagonal entry. */
    public void setJacobianRegularizationEps(double e) {
        this.jacobianRegularizationEps = e;
    }

    /** Sets the maximum number of step shrinkings per line search. */
    public void setStepSizeMaxTries(int stepSizeMaxTries) {
        this.stepSizeMaxTries = stepSizeMaxTries;
    }

    /** Copies the first {@code ndims} coordinates of {@code tgt} into the internal target buffer. */
    public void setTarget(double[] tgt) {
        System.arraycopy(tgt, 0, this.target, 0, this.ndims);
    }

    /** Returns the internal residual buffer (not a copy). */
    public double[] getErrorVector() {
        return this.errorV;
    }

    /** Returns the internal search-direction buffer (not a copy). */
    public double[] getDirection() {
        return this.dir;
    }

    /** Copies the first {@code ndims} coordinates of {@code est} into the estimate buffer. */
    public void setEstimate(double[] est) {
        System.arraycopy(est, 0, this.estimate, 0, this.ndims);
    }

    /** Copies the first {@code ndims} coordinates of {@code est} into the transformed-estimate buffer. */
    public void setEstimateXfm(double[] est) {
        System.arraycopy(est, 0, this.estimateXfm, 0, this.ndims);
    }

    /** Returns the internal estimate buffer (not a copy). */
    public double[] getEstimate() {
        return this.estimate;
    }

    /** Returns the Euclidean residual of the current estimate. */
    public double getError() {
        return this.error;
    }

    @Override
    public int numSourceDimensions() {
        return this.ndims;
    }

    @Override
    public int numTargetDimensions() {
        return this.ndims;
    }

    /**
     * Returns an independent instance with a copied forward transform and identical settings.
     * Per-call state (target, estimate, residual) is not carried over.
     */
    @Override
    public RealTransform copy() {
        InverseRealTransformGradientDescent copy = new InverseRealTransformGradientDescent(this.ndims, this.xfm.copy());
        copy.setBeta(this.beta);
        copy.setC(this.c);
        copy.setTolerance(this.tolerance);
        copy.setMaxIters(this.maxIters);
        // Carry over every tuning parameter so the copy runs exactly the same search.
        copy.setFixZ(this.fixZ);
        copy.setStepSize(this.stepSz);
        copy.setMinStep(this.minStepSize);
        copy.setMaxStep(this.maxStepSize);
        copy.setJacobianEstimateStep(this.jacobianEstimateStep);
        copy.setJacobianRegularizationEps(this.jacobianRegularizationEps);
        copy.setStepSizeMaxTries(this.stepSizeMaxTries);
        return copy;
    }

    /**
     * Does nothing.
     *
     * @deprecated no-op kept for API compatibility; pass the starting point to
     *             {@link #inverseTol(double[], double[], double, int)} instead.
     */
    @Deprecated
    public void setGuess(double[] guess) {
    }

    /**
     * Computes an approximate pre-image of {@code s} and writes it to {@code t}.
     * The search starts at {@code s} itself and uses the configured tolerance and iteration budget.
     */
    @Override
    public void apply(double[] s, double[] t) {
        this.inverseTol(s, s, this.tolerance, this.maxIters);
        System.arraycopy(this.estimate, 0, t, 0, t.length);
    }

    /** Single-precision variant of {@link #apply(double[], double[])}, routed through scratch buffers. */
    @Override
    @Deprecated
    public void apply(float[] src, float[] tgt) {
        for (int i = 0; i < src.length; ++i) {
            this.srcd[i] = src[i];
        }
        this.apply(this.srcd, this.tgtd);
        for (int i = 0; i < tgt.length; ++i) {
            tgt[i] = (float) this.tgtd[i];
        }
    }

    /** {@link RealLocalizable} variant of {@link #apply(double[], double[])}. */
    @Override
    public void apply(RealLocalizable src, RealPositionable tgt) {
        src.localize(this.srcd);
        this.apply(this.srcd, this.tgtd);
        tgt.setPosition(this.tgtd);
    }

    /**
     * Searches for a point {@code x} with {@code |target - xfm(x)| < tolerance}, starting at
     * {@code guess}. The result is left in {@link #getEstimate()}.
     *
     * @param target point to invert; its first {@code ndims} coordinates are copied, and the
     *               caller's array is neither retained nor modified
     * @param guess starting estimate; copied likewise
     * @param tolerance iteration stops once the Euclidean residual falls below this value
     * @param maxIters maximum number of descent iterations
     * @return the Euclidean residual {@code |target - xfm(estimate)|} at termination
     */
    public double inverseTol(double[] target, double[] guess, double tolerance, int maxIters) {
        // Copy instead of aliasing. Keeping the caller's array would let a later setTarget(...)
        // write into it; apply(double[], double[]) passes its source array here.
        this.setTarget(target);
        this.error = 999.0 * tolerance;
        this.setEstimate(guess);
        this.xfm.apply(this.estimate, this.estimateXfm);
        this.updateError();
        for (int k = 0; this.error >= tolerance && k < maxIters; ++k) {
            this.xfm.directionToward(this.dir, this.estimateXfm, this.target);
            // The current residual serves as the initial trial step of the line search.
            final double t = this.backtrackingLineSearch(this.error);
            // NOTE(review): the clamped line search never returns 0 while minStepSize > 0,
            // so this exit only triggers when minStepSize <= 0.
            if (t == 0.0) {
                break;
            }
            this.updateEstimate(t);
            this.xfm.apply(this.estimate, this.estimateXfm);
            this.updateError();
        }
        return this.error;
    }

    /** Adds {@code jacobianRegularizationEps} to every diagonal entry of {@link #jacobian}. */
    public void regularizeJacobian() {
        for (int i = 0; i < this.ndims; ++i) {
            this.jacobian.set(this.jacobianRegularizationEps + this.jacobian.get(i, i), i, i);
        }
    }

    /**
     * Backtracking line search using the configured {@code c}, {@code beta} and
     * {@code stepSizeMaxTries}. The resulting step is clamped to {@code minStepSize} from below,
     * then to {@code maxStepSize} from above.
     *
     * @param t0 initial trial step
     * @return the accepted (or last tried) step, clamped
     */
    public double backtrackingLineSearch(double t0) {
        double t = t0;
        for (int k = 0; k < this.stepSizeMaxTries && !this.armijoCondition(this.c, t); ++k) {
            t *= this.beta;
        }
        if (t < this.minStepSize) {
            return this.minStepSize;
        }
        if (t > this.maxStepSize) {
            return this.maxStepSize;
        }
        return t;
    }

    /**
     * Unclamped backtracking line search with explicit parameters.
     *
     * @param c Armijo sufficient-decrease constant
     * @param beta shrink factor applied after each failed test
     * @param maxtries maximum number of shrinkings
     * @param t0 initial trial step
     * @return the first step passing the Armijo test, or the last step tried
     */
    public double backtrackingLineSearch(double c, double beta, int maxtries, double t0) {
        double t = t0;
        for (int k = 0; k < maxtries && !this.armijoCondition(c, t); ++k) {
            t *= beta;
        }
        return t;
    }

    /**
     * Tests whether stepping {@code t} along {@link #dir} sufficiently decreases the squared
     * residual. The trial point is stored in {@code x_ap} and its image in {@code phix_ap}.
     * <p>
     * NOTE(review): the decrease term is scaled by {@link #descentDirectionMag}, which is 0 unless
     * set elsewhere. In that case the test reduces to a strict decrease {@code f(x + t d) < f(x)}.
     * Also, {@code sumSquaredErrorsDeriv} returns twice the squared residual rather than a
     * directional derivative. Confirm this is intended.
     *
     * @param c sufficient-decrease constant
     * @param t trial step length
     * @return {@code true} if the trial point satisfies the condition
     */
    public boolean armijoCondition(double c, double t) {
        double[] d = this.dir;
        double[] x = this.estimate;
        for (int i = 0; i < this.ndims; ++i) {
            this.x_ap[i] = x[i] + t * d[i];
        }
        double[] phix = this.estimateXfm;
        this.xfm.apply(this.x_ap, this.phix_ap);
        double fx = this.squaredError(phix);
        double fx_ap = this.squaredError(this.phix_ap);
        double m = this.sumSquaredErrorsDeriv(this.target, phix) * this.descentDirectionMag;
        return fx_ap < fx + c * t * m;
    }

    /** Returns the squared Euclidean distance between {@code x} and the current target. */
    public double squaredError(double[] x) {
        double error = 0.0;
        for (int i = 0; i < this.ndims; ++i) {
            error += (x[i] - this.target[i]) * (x[i] - this.target[i]);
        }
        return error;
    }

    /** Moves the estimate by {@code stepSize} along the current search direction. */
    public void updateEstimate(double stepSize) {
        for (int i = 0; i < this.ndims; ++i) {
            this.estimate[i] += stepSize * this.dir[i];
        }
    }

    /**
     * Recomputes the residual vector {@code target - estimateXfm} and its Euclidean norm.
     * Logs a warning and leaves the state unchanged if the target or estimate buffer is null.
     */
    public void updateError() {
        if (this.estimate == null || this.target == null) {
            logger.warn("WARNING: Call to updateError with null target or estimate");
            return;
        }
        double sumSq = 0.0;
        for (int i = 0; i < this.ndims; ++i) {
            this.errorV[i] = this.target[i] - this.estimateXfm[i];
            sumSq += this.errorV[i] * this.errorV[i];
        }
        this.error = Math.sqrt(sumSq);
    }

    /** Returns twice the sum of squared differences between {@code y} and {@code x} over {@code ndims} coordinates. */
    private double sumSquaredErrorsDeriv(double[] y, double[] x) {
        double errDeriv = 0.0;
        for (int i = 0; i < this.ndims; ++i) {
            errDeriv += (y[i] - x[i]) * (y[i] - x[i]);
        }
        return 2.0 * errDeriv;
    }

    /**
     * Returns the sum of squared differences between {@code y} and {@code x}.
     * The number of coordinates is taken from {@code y.length}.
     */
    public static double sumSquaredErrors(double[] y, double[] x) {
        int ndims = y.length;
        double err = 0.0;
        for (int i = 0; i < ndims; ++i) {
            err += (y[i] - x[i]) * (y[i] - x[i]);
        }
        return err;
    }
}

