/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.pixel_classification.gpu.algorithms;

import java.util.List;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuApi;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuPixelWiseOperation;
import sc.fiji.labkit.pixel_classification.gpu.api.GpuView;

/**
 * Computes per-pixel eigenvalues of symmetric 2x2 and 3x3 matrices on the GPU.
 *
 * <p>Each matrix is given as one {@link GpuView} per independent matrix entry (upper triangle),
 * and each eigenvalue is written to its own output {@link GpuView}.
 */
public class GpuEigenvalues {

    /**
     * Per-pixel kernel for a symmetric 2x2 matrix [[s_xx, s_xy], [s_xy, s_yy]].
     *
     * <p>Closed form: trace / 2 +/- sqrt(4 * s_xy^2 + (s_xx - s_yy)^2) / 2. Because the square
     * root is non-negative, {@code l} (the "+" branch) is never smaller than {@code s} (the "-"
     * branch). The fragments concatenate at compile time to a single kernel string.
     */
    private static final String SYMMETRIC_2D_KERNEL =
        "float trace = s_xx + s_yy;"
            + "l = (float) (trace / 2.0 + sqrt(4 * s_xy * s_xy + (s_xx - s_yy) * (s_xx - s_yy)) / 2.0);"
            + "s = (float) (trace / 2.0 - sqrt(4 * s_xy * s_xy + (s_xx - s_yy) * (s_xx - s_yy)) / 2.0);";

    /**
     * Per-pixel kernel for a symmetric 3x3 matrix.
     *
     * <p>Inputs are promoted to double. The kernel builds the coefficients of the characteristic
     * polynomial lambda^3 + a*lambda^2 + b*lambda + c, where a = -trace, b is the sum of the
     * principal 2x2 minors, and c = -det.
     *
     * <p>NOTE(review): {@code solve_cubic_equation} is defined elsewhere in the kernel sources. The
     * output mapping (x[2] -> large, x[0] -> small) implies it returns the roots in ascending
     * order; confirm against its definition.
     */
    private static final String SYMMETRIC_3D_KERNEL =
        "double g_xx = s_xx, g_xy = s_xy, g_xz = s_xz, g_yy = s_yy, g_yz = s_yz, g_zz = s_zz;"
            + "double a = -(g_xx + g_yy + g_zz);"
            + "double b = g_xx * g_yy + g_xx * g_zz + g_yy * g_zz - g_xy * g_xy - g_xz * g_xz - g_yz * g_yz;"
            + "double c = g_xx * (g_yz * g_yz - g_yy * g_zz) + g_yy * g_xz * g_xz + g_zz * g_xy * g_xy - 2 * g_xy * g_xz * g_yz;"
            + "double x[3];"
            + "solve_cubic_equation(c, b, a, x);"
            + "large = (float) x[2];"
            + "middle = (float) x[1];"
            + "small = (float) x[0];";

    /**
     * Dispatches to {@link #symmetric2d} or {@link #symmetric3d} based on the list sizes.
     *
     * @param gpu the GPU context used to run the kernel
     * @param matrix the upper-triangle matrix entries: 3 views (xx, xy, yy) for 2D, or 6 views
     *     (xx, xy, xz, yy, yz, zz) for 3D
     * @param eigenvalues the output views: 2 for 2D, 3 for 3D
     * @throws UnsupportedOperationException if the sizes are neither (3, 2) nor (6, 3)
     */
    public static void symmetric(GpuApi gpu, List<GpuView> matrix, List<GpuView> eigenvalues) {
        if (matrix.size() == 3 && eigenvalues.size() == 2) {
            GpuEigenvalues.symmetric2d(gpu, matrix.get(0), matrix.get(1), matrix.get(2), eigenvalues.get(0), eigenvalues.get(1));
        } else if (matrix.size() == 6 && eigenvalues.size() == 3) {
            GpuEigenvalues.symmetric3d(gpu, matrix.get(0), matrix.get(1), matrix.get(2), matrix.get(3), matrix.get(4), matrix.get(5), eigenvalues.get(0), eigenvalues.get(1), eigenvalues.get(2));
        } else {
            // Report what was received so mismatched callers can be diagnosed.
            throw new UnsupportedOperationException("Unsupported sizes: matrix has " + matrix.size()
                + " entries and eigenvalues has " + eigenvalues.size()
                + " outputs; expected (3, 2) for 2D or (6, 3) for 3D.");
        }
    }

    /**
     * Computes both eigenvalues of a symmetric 2x2 matrix for every pixel.
     *
     * @param gpu the GPU context used to run the kernel
     * @param xx matrix entry (0, 0)
     * @param xy matrix entry (0, 1) == (1, 0)
     * @param yy matrix entry (1, 1)
     * @param eigenvalue1 output for trace/2 + sqrt(...)/2, the larger eigenvalue
     * @param eigenvalue2 output for trace/2 - sqrt(...)/2, the smaller eigenvalue
     */
    public static void symmetric2d(GpuApi gpu, GpuView xx, GpuView xy, GpuView yy, GpuView eigenvalue1, GpuView eigenvalue2) {
        GpuPixelWiseOperation.gpu(gpu)
            .addInput("s_xx", xx)
            .addInput("s_xy", xy)
            .addInput("s_yy", yy)
            .addOutput("l", eigenvalue1)
            .addOutput("s", eigenvalue2)
            .forEachPixel(SYMMETRIC_2D_KERNEL);
    }

    /**
     * Computes the three eigenvalues of a symmetric 3x3 matrix for every pixel.
     *
     * @param gpu the GPU context used to run the kernel
     * @param xx matrix entry (0, 0)
     * @param xy matrix entry (0, 1) == (1, 0)
     * @param xz matrix entry (0, 2) == (2, 0)
     * @param yy matrix entry (1, 1)
     * @param yz matrix entry (1, 2) == (2, 1)
     * @param zz matrix entry (2, 2)
     * @param eigenvalue1 output bound to the kernel variable {@code large} (root x[2])
     * @param eigenvalue2 output bound to the kernel variable {@code middle} (root x[1])
     * @param eigenvalue3 output bound to the kernel variable {@code small} (root x[0])
     */
    public static void symmetric3d(GpuApi gpu, GpuView xx, GpuView xy, GpuView xz, GpuView yy, GpuView yz, GpuView zz, GpuView eigenvalue1, GpuView eigenvalue2, GpuView eigenvalue3) {
        GpuPixelWiseOperation.gpu(gpu)
            .addInput("s_xx", xx)
            .addInput("s_xy", xy)
            .addInput("s_xz", xz)
            .addInput("s_yy", yy)
            .addInput("s_yz", yz)
            .addInput("s_zz", zz)
            .addOutput("large", eigenvalue1)
            .addOutput("middle", eigenvalue2)
            .addOutput("small", eigenvalue3)
            .forEachPixel(SYMMETRIC_3D_KERNEL);
    }
}

