package mltk.predictor.glm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Attribute;
import mltk.core.DenseVector;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.SparseVector;
import mltk.core.io.InstancesReader;
import mltk.predictor.Learner;
import mltk.predictor.io.PredictorWriter;
import mltk.util.OptimUtils;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;
import weka.core.TestInstances;

/* loaded from: input_file:mltk/predictor/glm/LassoLearner.class */
public class LassoLearner extends Learner {
    // Compiler-generated ordinal lookup for switching on Learner.Task; build(Instances) reads it
    // through the $SWITCH_TABLE$mltk$predictor$Learner$Task() accessor defined elsewhere in this class.
    private static /* synthetic */ int[] $SWITCH_TABLE$mltk$predictor$Learner$Task;

    // Print the loss after every coordinate-descent pass.
    private boolean verbose = false;
    // Fit an intercept term before each pass.
    private boolean fitIntercept = true;
    // Re-fit the selected non-zero structure (refitClassifier / refitRegressor) after the Lasso fit.
    private boolean refit = false;

    // Maximum number of passes; a negative value lets build(Instances) pick 20 * instances.dimension().
    private int maxNumIters = -1;
    // Tolerance handed to OptimUtils.isConverged.
    private double epsilon = 1e-8;
    // L1 penalty used by build(Instances).
    private double lambda = 0.0;
    // Number of lambdas on a regularization path (not read in the visible code).
    private int numLambdas = -1;

    // Task that build(Instances) dispatches on.
    private Learner.Task task = Learner.Task.REGRESSION;

    /**
     * Value-equality key for the non-zero pattern of a coefficient vector.
     *
     * <p>The path builders keep these in a hash set so the same structure is refitted only once.
     * The array is stored by reference, so callers must not mutate it after construction.
     */
    public static class ModelStructure {
        boolean[] structure;

        ModelStructure(boolean[] structure) {
            this.structure = structure;
        }

        @Override
        public boolean equals(Object other) {
            if (other == this) {
                return true;
            }
            if (other == null || other.getClass() != getClass()) {
                return false;
            }
            ModelStructure that = (ModelStructure) other;
            return Arrays.equals(structure, that.structure);
        }

        @Override
        public int hashCode() {
            return 31 + Arrays.hashCode(structure);
        }
    }

    /**
     * Command-line options for {@link LassoLearner#main(String[])}.
     */
    static class Options {

        @Argument(name = "-r", description = "attribute file path")
        String attPath = null;

        @Argument(name = "-t", description = "train set path", required = true)
        String trainPath = null;

        @Argument(name = "-o", description = "output model path")
        String outputModelPath = null;

        @Argument(name = "-g", description = "task between classification (c) and regression (r) (default: r)")
        String task = "r";

        // Defaults to -1 rather than 0. With 0, the coordinate-descent loops never execute and the
        // trained model has all-zero coefficients. build(Instances) treats a negative value as
        // "use 20 * instances.dimension()". Assumes setMaxNumIters accepts negative values; the
        // learner's own field defaults to -1.
        @Argument(name = "-m", description = "maximum num of iterations (default: -1, i.e. 20 * dimension)")
        int maxIter = -1;

        @Argument(name = "-l", description = "lambda (default: 0)")
        double lambda = 0.0d;

        Options() {
        }
    }

    /**
     * Command-line entry point. Trains a Lasso model on the training set and, if requested,
     * writes it to the output path.
     *
     * @param strArr command-line arguments (see {@link Options})
     * @throws Exception propagated from reading the data or writing the model
     */
    public static void main(String[] strArr) throws Exception {
        Options opts = new Options();
        CmdLineParser parser = new CmdLineParser(LassoLearner.class, opts);
        Learner.Task task = null;
        try {
            parser.parse(strArr);
            task = Learner.Task.getEnum(opts.task);
        } catch (IllegalArgumentException ex) {
            parser.printUsage();
            System.exit(1);
        }

        Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath);

        LassoLearner learner = new LassoLearner();
        learner.setVerbose(true);
        learner.setTask(task);
        learner.setLambda(opts.lambda);
        learner.setMaxNumIters(opts.maxIter);

        long start = System.currentTimeMillis();
        GLM model = learner.build(trainSet);
        double elapsedSeconds = (System.currentTimeMillis() - start) / 1000.0;
        System.out.println("Time: " + elapsedSeconds);

        if (opts.outputModelPath != null) {
            PredictorWriter.write(model, opts.outputModelPath);
        }
    }

    /**
     * Builds a Lasso model for the configured task, using the configured lambda.
     *
     * <p>A negative {@code maxNumIters} means {@code 20 * instances.dimension()} passes. That value
     * is resolved for this call only, so the learner can be reused on data of a different dimension.
     *
     * @param instances the training data
     * @return the trained model, or {@code null} if the task matches neither case below
     */
    @Override // mltk.predictor.Learner
    public GLM build(Instances instances) {
        // Keep the "auto" budget local. Writing it back into the field would freeze the first
        // dataset's 20 * dimension into the learner's configuration.
        int numIters = this.maxNumIters < 0 ? instances.dimension() * 20 : this.maxNumIters;
        GLM glm = null;
        // Compiler-generated ordinal lookup for Learner.Task. Case 1 is classification, case 2 is regression.
        switch ($SWITCH_TABLE$mltk$predictor$Learner$Task()[this.task.ordinal()]) {
            case 1:
                glm = buildClassifier(instances, numIters, this.lambda);
                break;
            case 2:
                glm = buildRegressor(instances, numIters, this.lambda);
                break;
            default:
                break;
        }
        return glm;
    }

    /**
     * Fits an L1-regularized binary classifier on dense, column-major data by cyclic coordinate descent.
     *
     * @param attrs attribute index of each column
     * @param x feature columns; {@code x[j]} holds the values of attribute {@code attrs[j]}
     * @param y 0/1 targets (as prepared by buildClassifier)
     * @param maxNumIters maximum number of passes over all coordinates
     * @param lambda L1 penalty (scaled by the number of instances before being applied)
     * @return the fitted model, or the refitted one when {@code refit} is enabled
     */
    public GLM buildBinaryClassifier(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda) {
        final int n = y.length;
        double[] w = new double[attrs.length];
        double intercept = 0.0;
        double[] pTrain = new double[n];
        double[] rTrain = new double[n];
        OptimUtils.computePseudoResidual(pTrain, y, rTrain);

        // Per-column curvature term: squared norm divided by 4.
        double[] theta = new double[x.length];
        for (int j = 0; j < theta.length; j++) {
            theta[j] = StatUtils.sumSq(x[j]) / 4.0;
        }

        final double tl1 = lambda * n;
        for (int iter = 0; iter < maxNumIters; iter++) {
            double prevLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda);
            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
            }
            doOnePass(x, theta, y, tl1, w, pTrain, rTrain);
            double currLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda);
            if (verbose) {
                System.out.println("Iteration " + iter + ": " + TestInstances.DEFAULT_SEPARATORS + currLoss);
            }
            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }
        }

        if (refit) {
            boolean[] selected = new boolean[attrs.length];
            for (int j = 0; j < selected.length; j++) {
                selected[j] = w[j] != 0.0;
            }
            return refitClassifier(attrs, selected, x, y, maxNumIters);
        }
        return GLMOptimUtils.getGLM(attrs, w, intercept);
    }

    /**
     * Fits an L1-regularized binary classifier on sparse, column-major data by cyclic coordinate descent.
     *
     * @param attrs attribute index of each column
     * @param indices per column, the instance indices with stored values
     * @param values per column, the stored values aligned with {@code indices}
     * @param y 0/1 targets (as prepared by buildClassifier)
     * @param maxNumIters maximum number of passes over all coordinates
     * @param lambda L1 penalty (scaled by the number of instances before being applied)
     * @return the fitted model, or the refitted one when {@code refit} is enabled
     */
    public GLM buildBinaryClassifier(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters,
            double lambda) {
        double[] w = new double[attrs.length];
        double intercept = 0.0;
        double[] pTrain = new double[y.length];
        double[] rTrain = new double[y.length];
        OptimUtils.computePseudoResidual(pTrain, y, rTrain);

        double[] theta = new double[values.length];
        for (int j = 0; j < values.length; j++) {
            theta[j] = StatUtils.sumSq(values[j]) / 4.0;
        }

        final double tl1 = lambda * y.length;
        for (int iter = 0; iter < maxNumIters; iter++) {
            double prevLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda);
            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
            }
            doOnePass(indices, values, theta, y, tl1, w, pTrain, rTrain);
            double currLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda);
            if (verbose) {
                System.out.println("Iteration " + iter + ": " + TestInstances.DEFAULT_SEPARATORS + currLoss);
            }
            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }
        }

        if (!refit) {
            return GLMOptimUtils.getGLM(attrs, w, intercept);
        }
        boolean[] selected = new boolean[attrs.length];
        for (int j = 0; j < selected.length; j++) {
            selected[j] = w[j] != 0.0;
        }
        return refitClassifier(attrs, selected, indices, values, y, maxNumIters);
    }

    /**
     * Fits a regularization path of binary classifiers on dense, column-major data.
     *
     * <p>Model {@code k} is fitted at {@code maxLambda * alpha^k} for {@code k = 0..numLambdas-1}.
     * Here {@code alpha = minLambdaRatio^(1/numLambdas)}, and {@code maxLambda} comes from
     * {@code findMaxLambda}. Each fit warm-starts from the previous coefficients. With
     * {@code refit} enabled, a sparsity pattern already seen is skipped, so fewer than
     * {@code numLambdas} models may be returned.
     *
     * @param attrs attribute index of each column
     * @param x feature columns
     * @param y 0/1 targets
     * @param maxNumIters maximum passes per lambda
     * @param numLambdas number of lambdas on the path
     * @param minLambdaRatio ratio controlling how fast lambda decays along the path
     * @return the fitted (or refitted) models, in path order
     */
    public List<GLM> buildBinaryClassifiers(int[] attrs, double[][] x, double[] y, int maxNumIters, int numLambdas,
            double minLambdaRatio) {
        double[] w = new double[attrs.length];
        double intercept = 0.0;
        double[] pTrain = new double[y.length];
        double[] rTrain = new double[y.length];
        OptimUtils.computePseudoResidual(pTrain, y, rTrain);
        double[] theta = new double[x.length];
        for (int j = 0; j < x.length; j++) {
            theta[j] = StatUtils.sumSq(x[j]) / 4.0;
        }

        double maxLambda = findMaxLambda(x, y, pTrain, rTrain);
        double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas);
        List<GLM> glms = new ArrayList<>(numLambdas);
        Set<ModelStructure> structures = new HashSet<>();
        double lambda = maxLambda;
        for (int g = 0; g < numLambdas; g++) {
            double tl1 = lambda * y.length;
            for (int iter = 0; iter < maxNumIters; iter++) {
                double prevLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda);
                if (fitIntercept) {
                    intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
                }
                doOnePass(x, theta, y, tl1, w, pTrain, rTrain);
                double currLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda);
                if (verbose) {
                    System.out.println("Iteration " + iter + ": " + TestInstances.DEFAULT_SEPARATORS + currLoss);
                }
                if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                    break;
                }
            }
            lambda *= alpha;
            if (refit) {
                boolean[] selected = new boolean[attrs.length];
                for (int j = 0; j < selected.length; j++) {
                    selected[j] = w[j] != 0.0;
                }
                ModelStructure structure = new ModelStructure(selected);
                if (!structures.contains(structure)) {
                    glms.add(refitClassifier(attrs, selected, x, y, maxNumIters));
                    structures.add(structure);
                }
            } else {
                glms.add(GLMOptimUtils.getGLM(attrs, w, intercept));
            }
        }
        return glms;
    }

    /**
     * Fits a regularization path of binary classifiers on sparse, column-major data.
     *
     * <p>Model {@code k} is fitted at {@code maxLambda * alpha^k} with
     * {@code alpha = minLambdaRatio^(1/numLambdas)}, warm-starting from the previous coefficients.
     * With {@code refit} enabled, repeated sparsity patterns are skipped, so fewer than
     * {@code numLambdas} models may be returned.
     *
     * @param attrs attribute index of each column
     * @param indices per column, the instance indices with stored values
     * @param values per column, the stored values aligned with {@code indices}
     * @param y 0/1 targets
     * @param maxNumIters maximum passes per lambda
     * @param numLambdas number of lambdas on the path
     * @param minLambdaRatio ratio controlling how fast lambda decays along the path
     * @return the fitted (or refitted) models, in path order
     */
    public List<GLM> buildBinaryClassifiers(int[] attrs, int[][] indices, double[][] values, double[] y,
            int maxNumIters, int numLambdas, double minLambdaRatio) {
        double[] w = new double[attrs.length];
        double intercept = 0.0;
        double[] pTrain = new double[y.length];
        double[] rTrain = new double[y.length];
        OptimUtils.computePseudoResidual(pTrain, y, rTrain);
        double[] theta = new double[values.length];
        for (int j = 0; j < values.length; j++) {
            theta[j] = StatUtils.sumSq(values[j]) / 4.0;
        }

        double maxLambda = findMaxLambda(indices, values, y, pTrain, rTrain);
        double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas);
        List<GLM> glms = new ArrayList<>(numLambdas);
        Set<ModelStructure> structures = new HashSet<>();
        double lambda = maxLambda;
        for (int g = 0; g < numLambdas; g++) {
            double tl1 = lambda * y.length;
            for (int iter = 0; iter < maxNumIters; iter++) {
                double prevLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda);
                if (fitIntercept) {
                    intercept += OptimUtils.fitIntercept(pTrain, rTrain, y);
                }
                doOnePass(indices, values, theta, y, tl1, w, pTrain, rTrain);
                double currLoss = GLMOptimUtils.computeLassoLoss(pTrain, y, w, lambda);
                if (verbose) {
                    System.out.println("Iteration " + iter + ": " + TestInstances.DEFAULT_SEPARATORS + currLoss);
                }
                if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                    break;
                }
            }
            lambda *= alpha;
            if (refit) {
                boolean[] selected = new boolean[attrs.length];
                for (int j = 0; j < selected.length; j++) {
                    selected[j] = w[j] != 0.0;
                }
                ModelStructure structure = new ModelStructure(selected);
                if (!structures.contains(structure)) {
                    glms.add(refitClassifier(attrs, selected, indices, values, y, maxNumIters));
                    structures.add(structure);
                }
            } else {
                glms.add(GLMOptimUtils.getGLM(attrs, w, intercept));
            }
        }
        return glms;
    }

    /**
     * Builds a Lasso classifier.
     *
     * <p>With two classes, a single binary model is fitted with class 0 as the positive (1) label.
     * With more classes, one binary model per class is fitted one-vs-rest. Row {@code c} of the
     * returned model holds class {@code c}'s coefficients and intercept. Coefficients are
     * multiplied by the dataset's {@code cList} factors.
     *
     * @param instances training data; the target attribute must be nominal
     * @param isSparse whether to use the sparse dataset representation
     * @param maxNumIters maximum passes per binary fit
     * @param lambda L1 penalty
     * @return the trained classifier
     * @throws IllegalArgumentException if the target attribute is not nominal
     */
    public GLM buildClassifier(Instances instances, boolean isSparse, int maxNumIters, double lambda) {
        Attribute target = instances.getTargetAttribute();
        if (target.getType() != Attribute.Type.NOMINAL) {
            throw new IllegalArgumentException("Class attribute must be nominal.");
        }
        int numClasses = ((NominalAttribute) target).getStates().length;

        if (isSparse) {
            Learner.SparseDataset sd = getSparseDataset(instances, true);
            int[] attrs = sd.attrs;
            int[][] indices = sd.indices;
            double[][] values = sd.values;
            double[] y = new double[sd.y.length];
            double[] cList = sd.cList;

            if (numClasses == 2) {
                for (int k = 0; k < y.length; k++) {
                    y[k] = ((int) sd.y[k]) == 0 ? 1 : 0;
                }
                GLM binary = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda);
                double[] coef = binary.w[0];
                for (int j = 0; j < cList.length; j++) {
                    coef[attrs[j]] *= cList[j];
                }
                return binary;
            }

            GLM multi = new GLM(numClasses, attrs.length == 0 ? 0 : StatUtils.max(attrs) + 1);
            for (int c = 0; c < numClasses; c++) {
                for (int k = 0; k < y.length; k++) {
                    y[k] = ((int) sd.y[k]) == c ? 1 : 0;
                }
                GLM binary = buildBinaryClassifier(attrs, indices, values, y, maxNumIters, lambda);
                double[] coef = binary.w[0];
                for (int j = 0; j < cList.length; j++) {
                    int attr = attrs[j];
                    multi.w[c][attr] = coef[attr] * cList[j];
                }
                multi.intercept[c] = binary.intercept[0];
            }
            return multi;
        }

        Learner.DenseDataset dd = getDenseDataset(instances, true);
        int[] attrs = dd.attrs;
        double[][] x = dd.x;
        double[] y = new double[dd.y.length];
        double[] cList = dd.cList;

        if (numClasses == 2) {
            for (int k = 0; k < y.length; k++) {
                y[k] = ((int) dd.y[k]) == 0 ? 1 : 0;
            }
            GLM binary = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda);
            double[] coef = binary.w[0];
            for (int j = 0; j < cList.length; j++) {
                coef[attrs[j]] *= cList[j];
            }
            return binary;
        }

        GLM multi = new GLM(numClasses, attrs.length == 0 ? 0 : StatUtils.max(attrs) + 1);
        for (int c = 0; c < numClasses; c++) {
            for (int k = 0; k < y.length; k++) {
                y[k] = ((int) dd.y[k]) == c ? 1 : 0;
            }
            GLM binary = buildBinaryClassifier(attrs, x, y, maxNumIters, lambda);
            double[] coef = binary.w[0];
            for (int j = 0; j < cList.length; j++) {
                int attr = attrs[j];
                multi.w[c][attr] = coef[attr] * cList[j];
            }
            multi.intercept[c] = binary.intercept[0];
        }
        return multi;
    }

    /**
     * Builds a Lasso classifier, choosing the sparse or dense representation with {@code isSparse}.
     *
     * @param instances training data; the target attribute must be nominal
     * @param maxNumIters maximum passes per binary fit
     * @param lambda L1 penalty
     * @return the trained classifier
     */
    public GLM buildClassifier(Instances instances, int maxNumIters, double lambda) {
        boolean sparse = isSparse(instances);
        return buildClassifier(instances, sparse, maxNumIters, lambda);
    }

    /**
     * Builds a regularization path of Lasso classifiers.
     *
     * <p>With two classes, the binary path is returned directly, with class 0 as the positive
     * label. With more classes, each class gets its own one-vs-rest path. Model {@code g} of the
     * result combines the {@code g}-th binary model of every class. Coefficients are multiplied
     * by the dataset's {@code cList} factors.
     *
     * @param instances training data; the target attribute must be nominal
     * @param isSparse whether to use the sparse dataset representation
     * @param maxNumIters maximum passes per lambda
     * @param numLambdas number of lambdas on the path
     * @param minLambdaRatio ratio controlling how fast lambda decays along the path
     * @return the models along the path
     * @throws IllegalArgumentException if the target attribute is not nominal
     */
    public List<GLM> buildClassifiers(Instances instances, boolean isSparse, int maxNumIters, int numLambdas,
            double minLambdaRatio) {
        Attribute targetAttribute = instances.getTargetAttribute();
        if (targetAttribute.getType() != Attribute.Type.NOMINAL) {
            throw new IllegalArgumentException("Class attribute must be nominal.");
        }
        int numClasses = ((NominalAttribute) targetAttribute).getStates().length;

        if (!isSparse) {
            Learner.DenseDataset dd = getDenseDataset(instances, true);
            int[] attrs = dd.attrs;
            double[][] x = dd.x;
            double[] y = new double[dd.y.length];
            double[] cList = dd.cList;
            if (numClasses == 2) {
                for (int i = 0; i < y.length; i++) {
                    y[i] = ((int) dd.y[i]) == 0 ? 1 : 0;
                }
                List<GLM> glms = buildBinaryClassifiers(attrs, x, y, maxNumIters, numLambdas, minLambdaRatio);
                for (GLM glm : glms) {
                    double[] w = glm.w[0];
                    for (int j = 0; j < cList.length; j++) {
                        w[attrs[j]] *= cList[j];
                    }
                }
                return glms;
            }
            // Per-class paths must line up lambda-by-lambda. With refit on, a path can drop
            // duplicate structures and come back shorter, so refit is switched off here. The
            // finally block restores it even if training throws.
            boolean savedRefit = this.refit;
            this.refit = false;
            try {
                int p = attrs.length == 0 ? 0 : StatUtils.max(attrs) + 1;
                List<GLM> glms = new ArrayList<>();
                for (int g = 0; g < numLambdas; g++) {
                    glms.add(new GLM(numClasses, p));
                }
                for (int k = 0; k < numClasses; k++) {
                    for (int i = 0; i < y.length; i++) {
                        y[i] = ((int) dd.y[i]) == k ? 1 : 0;
                    }
                    List<GLM> path = buildBinaryClassifiers(attrs, x, y, maxNumIters, numLambdas, minLambdaRatio);
                    for (int g = 0; g < numLambdas; g++) {
                        GLM binary = path.get(g);
                        GLM glm = glms.get(g);
                        double[] w = binary.w[0];
                        for (int j = 0; j < cList.length; j++) {
                            int attr = attrs[j];
                            glm.w[k][attr] = w[attr] * cList[j];
                        }
                        glm.intercept[k] = binary.intercept[0];
                    }
                }
                return glms;
            } finally {
                this.refit = savedRefit;
            }
        }

        Learner.SparseDataset sd = getSparseDataset(instances, true);
        int[] attrs = sd.attrs;
        int[][] indices = sd.indices;
        double[][] values = sd.values;
        double[] y = new double[sd.y.length];
        double[] cList = sd.cList;
        if (numClasses == 2) {
            for (int i = 0; i < y.length; i++) {
                y[i] = ((int) sd.y[i]) == 0 ? 1 : 0;
            }
            List<GLM> glms = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, numLambdas,
                    minLambdaRatio);
            for (GLM glm : glms) {
                double[] w = glm.w[0];
                for (int j = 0; j < cList.length; j++) {
                    w[attrs[j]] *= cList[j];
                }
            }
            return glms;
        }
        // See the dense branch: refit is disabled so every per-class path has numLambdas models.
        boolean savedRefit = this.refit;
        this.refit = false;
        try {
            int p = attrs.length == 0 ? 0 : StatUtils.max(attrs) + 1;
            List<GLM> glms = new ArrayList<>();
            for (int g = 0; g < numLambdas; g++) {
                glms.add(new GLM(numClasses, p));
            }
            for (int k = 0; k < numClasses; k++) {
                for (int i = 0; i < y.length; i++) {
                    y[i] = ((int) sd.y[i]) == k ? 1 : 0;
                }
                List<GLM> path = buildBinaryClassifiers(attrs, indices, values, y, maxNumIters, numLambdas,
                        minLambdaRatio);
                for (int g = 0; g < numLambdas; g++) {
                    GLM binary = path.get(g);
                    GLM glm = glms.get(g);
                    double[] w = binary.w[0];
                    for (int j = 0; j < cList.length; j++) {
                        int attr = attrs[j];
                        glm.w[k][attr] = w[attr] * cList[j];
                    }
                    glm.intercept[k] = binary.intercept[0];
                }
            }
            return glms;
        } finally {
            this.refit = savedRefit;
        }
    }

    /**
     * Builds a classifier regularization path, choosing the sparse or dense representation with
     * {@code isSparse}.
     *
     * @param instances training data; the target attribute must be nominal
     * @param maxNumIters maximum passes per lambda
     * @param numLambdas number of lambdas on the path
     * @param minLambdaRatio ratio controlling how fast lambda decays along the path
     * @return the models along the path
     */
    public List<GLM> buildClassifiers(Instances instances, int maxNumIters, int numLambdas, double minLambdaRatio) {
        boolean sparse = isSparse(instances);
        return buildClassifiers(instances, sparse, maxNumIters, numLambdas, minLambdaRatio);
    }

    /**
     * Builds a Lasso regressor. Coefficients are multiplied by the dataset's {@code cList} factors
     * (presumably mapping them back from the normalized feature scale; confirm in Learner).
     *
     * @param instances training data
     * @param isSparse whether to use the sparse dataset representation
     * @param maxNumIters maximum number of passes
     * @param lambda L1 penalty
     * @return the trained regressor
     */
    public GLM buildRegressor(Instances instances, boolean isSparse, int maxNumIters, double lambda) {
        if (!isSparse) {
            Learner.DenseDataset dd = getDenseDataset(instances, true);
            double[] cList = dd.cList;
            GLM glm = buildRegressor(dd.attrs, dd.x, dd.y, maxNumIters, lambda);
            double[] coef = glm.w[0];
            for (int j = 0; j < cList.length; j++) {
                coef[dd.attrs[j]] *= cList[j];
            }
            return glm;
        }
        Learner.SparseDataset sd = getSparseDataset(instances, true);
        double[] cList = sd.cList;
        GLM glm = buildRegressor(sd.attrs, sd.indices, sd.values, sd.y, maxNumIters, lambda);
        double[] coef = glm.w[0];
        for (int j = 0; j < cList.length; j++) {
            coef[sd.attrs[j]] *= cList[j];
        }
        return glm;
    }

    /**
     * Builds a Lasso regressor, choosing the sparse or dense representation with {@code isSparse}.
     *
     * @param instances training data
     * @param maxNumIters maximum number of passes
     * @param lambda L1 penalty
     * @return the trained regressor
     */
    public GLM buildRegressor(Instances instances, int maxNumIters, double lambda) {
        boolean sparse = isSparse(instances);
        return buildRegressor(instances, sparse, maxNumIters, lambda);
    }

    /**
     * Fits a Lasso regressor on dense, column-major data by cyclic coordinate descent.
     *
     * @param attrs attribute index of each column
     * @param x feature columns; {@code x[j]} holds the values of attribute {@code attrs[j]}
     * @param y regression targets
     * @param maxNumIters maximum number of passes
     * @param lambda L1 penalty (scaled by the number of instances before being applied)
     * @return the fitted model, or the refitted one when {@code refit} is enabled
     */
    public GLM buildRegressor(int[] attrs, double[][] x, double[] y, int maxNumIters, double lambda) {
        double[] w = new double[attrs.length];
        double intercept = 0.0;
        // All coefficients start at zero, so the residuals start as the targets.
        double[] rTrain = y.clone();
        double[] sq = new double[x.length];
        for (int j = 0; j < x.length; j++) {
            sq[j] = StatUtils.sumSq(x[j]);
        }

        final double tl1 = lambda * y.length;
        for (int iter = 0; iter < maxNumIters; iter++) {
            double prevLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda);
            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(rTrain);
            }
            doOnePass(x, sq, tl1, w, rTrain);
            double currLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda);
            if (verbose) {
                System.out.println("Iteration " + iter + ": " + TestInstances.DEFAULT_SEPARATORS + currLoss);
            }
            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }
        }

        if (refit) {
            boolean[] selected = new boolean[attrs.length];
            for (int j = 0; j < selected.length; j++) {
                selected[j] = w[j] != 0.0;
            }
            return refitRegressor(attrs, selected, x, y, maxNumIters);
        }
        return GLMOptimUtils.getGLM(attrs, w, intercept);
    }

    /**
     * Fits a Lasso regressor on sparse, column-major data by cyclic coordinate descent.
     *
     * @param attrs attribute index of each column
     * @param indices per column, the instance indices with stored values
     * @param values per column, the stored values aligned with {@code indices}
     * @param y regression targets
     * @param maxNumIters maximum number of passes
     * @param lambda L1 penalty (scaled by the number of instances before being applied)
     * @return the fitted model, or the refitted one when {@code refit} is enabled
     */
    public GLM buildRegressor(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters,
            double lambda) {
        double[] w = new double[attrs.length];
        double intercept = 0.0;
        double[] rTrain = y.clone();
        double[] sq = new double[values.length];
        for (int j = 0; j < values.length; j++) {
            sq[j] = StatUtils.sumSq(values[j]);
        }

        final double tl1 = lambda * y.length;
        for (int iter = 0; iter < maxNumIters; iter++) {
            double prevLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda);
            if (fitIntercept) {
                intercept += OptimUtils.fitIntercept(rTrain);
            }
            doOnePass(indices, values, sq, tl1, w, rTrain);
            double currLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda);
            if (verbose) {
                System.out.println("Iteration " + iter + ": " + TestInstances.DEFAULT_SEPARATORS + currLoss);
            }
            if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                break;
            }
        }

        if (!refit) {
            return GLMOptimUtils.getGLM(attrs, w, intercept);
        }
        boolean[] selected = new boolean[attrs.length];
        for (int j = 0; j < selected.length; j++) {
            selected[j] = w[j] != 0.0;
        }
        return refitRegressor(attrs, selected, indices, values, y, maxNumIters);
    }

    /**
     * Builds a regularization path of Lasso regressors. Every model's coefficients are multiplied
     * by the dataset's {@code cList} factors (presumably mapping them back from the normalized
     * feature scale; confirm in Learner).
     *
     * @param instances training data
     * @param isSparse whether to use the sparse dataset representation
     * @param maxNumIters maximum passes per lambda
     * @param numLambdas number of lambdas on the path
     * @param minLambdaRatio ratio controlling how fast lambda decays along the path
     * @return the models along the path
     */
    public List<GLM> buildRegressors(Instances instances, boolean isSparse, int maxNumIters, int numLambdas,
            double minLambdaRatio) {
        List<GLM> glms;
        int[] attrs;
        double[] cList;
        if (isSparse) {
            Learner.SparseDataset sd = getSparseDataset(instances, true);
            attrs = sd.attrs;
            cList = sd.cList;
            glms = buildRegressors(sd.attrs, sd.indices, sd.values, sd.y, maxNumIters, numLambdas, minLambdaRatio);
        } else {
            Learner.DenseDataset dd = getDenseDataset(instances, true);
            attrs = dd.attrs;
            cList = dd.cList;
            glms = buildRegressors(dd.attrs, dd.x, dd.y, maxNumIters, numLambdas, minLambdaRatio);
        }
        for (GLM glm : glms) {
            double[] coef = glm.w[0];
            for (int j = 0; j < cList.length; j++) {
                coef[attrs[j]] *= cList[j];
            }
        }
        return glms;
    }

    /**
     * Builds a regressor regularization path, choosing the sparse or dense representation with
     * {@code isSparse}.
     *
     * @param instances training data
     * @param maxNumIters maximum passes per lambda
     * @param numLambdas number of lambdas on the path
     * @param minLambdaRatio ratio controlling how fast lambda decays along the path
     * @return the models along the path
     */
    public List<GLM> buildRegressors(Instances instances, int maxNumIters, int numLambdas, double minLambdaRatio) {
        boolean sparse = isSparse(instances);
        return buildRegressors(instances, sparse, maxNumIters, numLambdas, minLambdaRatio);
    }

    /**
     * Fits a regularization path of Lasso regressors on dense, column-major data.
     *
     * <p>Model {@code k} is fitted at {@code maxLambda * alpha^k} for {@code k = 0..numLambdas-1}.
     * Here {@code alpha = minLambdaRatio^(1/numLambdas)}, and each fit warm-starts from the
     * previous coefficients. With {@code refit} enabled, repeated sparsity patterns are skipped,
     * so fewer than {@code numLambdas} models may be returned.
     *
     * @param attrs attribute index of each column
     * @param x feature columns
     * @param y regression targets
     * @param maxNumIters maximum passes per lambda
     * @param numLambdas number of lambdas on the path
     * @param minLambdaRatio ratio controlling how fast lambda decays along the path
     * @return the fitted (or refitted) models, in path order
     */
    public List<GLM> buildRegressors(int[] attrs, double[][] x, double[] y, int maxNumIters, int numLambdas,
            double minLambdaRatio) {
        double[] w = new double[attrs.length];
        double intercept = 0.0;
        // All coefficients start at zero, so the residuals start as the targets.
        double[] rTrain = y.clone();
        double[] sq = new double[x.length];
        for (int j = 0; j < x.length; j++) {
            sq[j] = StatUtils.sumSq(x[j]);
        }

        double maxLambda = findMaxLambda(x, rTrain);
        double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas);
        List<GLM> glms = new ArrayList<>(numLambdas);
        Set<ModelStructure> structures = new HashSet<>();
        double lambda = maxLambda;
        for (int g = 0; g < numLambdas; g++) {
            double tl1 = lambda * y.length;
            for (int iter = 0; iter < maxNumIters; iter++) {
                double prevLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda);
                if (fitIntercept) {
                    intercept += OptimUtils.fitIntercept(rTrain);
                }
                doOnePass(x, sq, tl1, w, rTrain);
                double currLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda);
                if (verbose) {
                    System.out.println("Iteration " + iter + ": " + TestInstances.DEFAULT_SEPARATORS + currLoss);
                }
                if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                    break;
                }
            }
            lambda *= alpha;
            if (refit) {
                boolean[] selected = new boolean[attrs.length];
                for (int j = 0; j < selected.length; j++) {
                    selected[j] = w[j] != 0.0;
                }
                ModelStructure structure = new ModelStructure(selected);
                if (!structures.contains(structure)) {
                    glms.add(refitRegressor(attrs, selected, x, y, maxNumIters));
                    structures.add(structure);
                }
            } else {
                glms.add(GLMOptimUtils.getGLM(attrs, w, intercept));
            }
        }
        return glms;
    }

    /**
     * Fits a regularization path of Lasso regressors on sparse, column-major data.
     *
     * <p>Model {@code k} is fitted at {@code maxLambda * alpha^k} with
     * {@code alpha = minLambdaRatio^(1/numLambdas)}, warm-starting from the previous coefficients.
     * With {@code refit} enabled, repeated sparsity patterns are skipped, so fewer than
     * {@code numLambdas} models may be returned.
     *
     * @param attrs attribute index of each column
     * @param indices per column, the instance indices with stored values
     * @param values per column, the stored values aligned with {@code indices}
     * @param y regression targets
     * @param maxNumIters maximum passes per lambda
     * @param numLambdas number of lambdas on the path
     * @param minLambdaRatio ratio controlling how fast lambda decays along the path
     * @return the fitted (or refitted) models, in path order
     */
    public List<GLM> buildRegressors(int[] attrs, int[][] indices, double[][] values, double[] y, int maxNumIters,
            int numLambdas, double minLambdaRatio) {
        double[] w = new double[attrs.length];
        double intercept = 0.0;
        double[] rTrain = y.clone();
        double[] sq = new double[values.length];
        for (int j = 0; j < values.length; j++) {
            sq[j] = StatUtils.sumSq(values[j]);
        }

        double maxLambda = findMaxLambda(indices, values, y);
        double alpha = Math.pow(minLambdaRatio, 1.0 / numLambdas);
        List<GLM> glms = new ArrayList<>(numLambdas);
        Set<ModelStructure> structures = new HashSet<>();
        double lambda = maxLambda;
        for (int g = 0; g < numLambdas; g++) {
            double tl1 = lambda * y.length;
            for (int iter = 0; iter < maxNumIters; iter++) {
                double prevLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda);
                if (fitIntercept) {
                    intercept += OptimUtils.fitIntercept(rTrain);
                }
                doOnePass(indices, values, sq, tl1, w, rTrain);
                double currLoss = GLMOptimUtils.computeLassoLoss(rTrain, w, lambda);
                if (verbose) {
                    System.out.println("Iteration " + iter + ": " + TestInstances.DEFAULT_SEPARATORS + currLoss);
                }
                if (OptimUtils.isConverged(prevLoss, currLoss, epsilon)) {
                    break;
                }
            }
            lambda *= alpha;
            if (refit) {
                boolean[] selected = new boolean[attrs.length];
                for (int j = 0; j < selected.length; j++) {
                    selected[j] = w[j] != 0.0;
                }
                ModelStructure structure = new ModelStructure(selected);
                if (!structures.contains(structure)) {
                    glms.add(refitRegressor(attrs, selected, indices, values, y, maxNumIters));
                    structures.add(structure);
                }
            } else {
                glms.add(GLMOptimUtils.getGLM(attrs, w, intercept));
            }
        }
        return glms;
    }

    /**
     * Runs one cyclic coordinate-descent pass for Lasso regression on dense, column-major data.
     * Coefficients and residuals are updated in place.
     *
     * @param x feature columns; {@code x[j]} is the column for coefficient {@code j}
     * @param sq squared norm of each column
     * @param tl1 L1 threshold (callers pass lambda times the number of instances)
     * @param w coefficients, updated in place
     * @param rTrain residuals, updated in place
     */
    protected void doOnePass(double[][] x, double[] sq, double tl1, double[] w, double[] rTrain) {
        for (int j = 0; j < x.length; j++) {
            // Skip (near-)zero columns, using the same guard as the classification passes. For an
            // all-zero column the update below would compute 0/0 = NaN, and that NaN would then
            // spread into every residual. Such a coefficient stays unchanged (zero from the start).
            if (Math.abs(sq[j]) > 1.0E-8d) {
                double[] v = x[j];
                double z = w[j] * sq[j] + VectorUtils.dotProduct(v, rTrain);
                double shrunk = Math.abs(z) <= tl1 ? 0.0d : (z > 0.0d ? z - tl1 : z + tl1);
                double newW = shrunk / sq[j];
                double delta = newW - w[j];
                w[j] = newW;
                for (int i = 0; i < rTrain.length; i++) {
                    rTrain[i] -= delta * v[i];
                }
            }
        }
    }

    /**
     * Runs one cyclic coordinate-descent pass for the binary classifier on dense, column-major
     * data. Coefficients, predictions and pseudo-residuals are updated in place. Columns whose
     * curvature term is at most 1e-8 are skipped.
     *
     * @param x feature columns
     * @param theta per-column curvature term (squared norm / 4 in the callers)
     * @param y 0/1 targets
     * @param tl1 L1 threshold (callers pass lambda times the number of instances)
     * @param w coefficients, updated in place
     * @param pTrain current predictions, updated in place
     * @param rTrain pseudo-residuals, recomputed from {@code pTrain} and {@code y} as coefficients change
     */
    protected void doOnePass(double[][] x, double[] theta, double[] y, double tl1, double[] w, double[] pTrain,
            double[] rTrain) {
        for (int j = 0; j < x.length; j++) {
            double curvature = theta[j];
            if (Math.abs(curvature) > 1.0E-8d) {
                double[] v = x[j];
                double z = w[j] + VectorUtils.dotProduct(rTrain, v) / curvature;
                double threshold = tl1 / curvature;
                double newW;
                if (z > threshold) {
                    newW = z - threshold;
                } else if (z < -threshold) {
                    newW = z + threshold;
                } else {
                    newW = 0.0d;
                }
                double delta = newW - w[j];
                w[j] = newW;
                for (int i = 0; i < pTrain.length; i++) {
                    pTrain[i] += delta * v[i];
                    rTrain[i] = OptimUtils.getPseudoResidual(pTrain[i], y[i]);
                }
            }
        }
    }

    /**
     * Performs one sweep of cyclic coordinate descent for sparse Lasso regression.
     *
     * <p>Feature {@code j} is stored column-wise as parallel arrays: row indices
     * {@code iArr[j]} and values {@code dArr[j]}. The un-penalized coordinate optimum
     * {@code w[j] * dArr2[j] + <x_j, r>} is soft-thresholded by {@code d} and divided by
     * {@code dArr2[j]}. Only the residuals at the column's non-zero rows are then updated.
     *
     * <p>A feature whose denominator {@code dArr2[j]} is exactly zero is skipped. Dividing by it
     * would store NaN (0/0) or +/-Infinity as the weight and corrupt the residuals.
     *
     * @param iArr  per-feature row indices of the non-zero entries
     * @param dArr  per-feature non-zero values, aligned with {@code iArr}
     * @param dArr2 per-feature denominators of the coordinate update (presumably the column
     *              sums of squares; TODO confirm against the caller)
     * @param d     soft-threshold applied to each coordinate
     * @param dArr3 coefficient vector, updated in place
     * @param dArr4 residual vector, updated in place
     */
    protected void doOnePass(int[][] iArr, double[][] dArr, double[] dArr2, double d, double[] dArr3, double[] dArr4) {
        for (int j = 0; j < iArr.length; j++) {
            double denom = dArr2[j];
            if (denom == 0.0d) {
                // Leave the weight untouched rather than turning it (and the residuals) into NaN.
                continue;
            }
            int[] rows = iArr[j];
            double[] values = dArr[j];
            double z = dArr3[j] * denom;
            for (int k = 0; k < rows.length; k++) {
                z += dArr4[rows[k]] * values[k];
            }
            double wNew;
            if (Math.abs(z) <= d) {
                wNew = 0.0d;
            } else if (z > 0.0d) {
                wNew = z - d;
            } else {
                wNew = z + d;
            }
            wNew /= denom;
            double delta = wNew - dArr3[j];
            dArr3[j] = wNew;
            for (int k = 0; k < rows.length; k++) {
                dArr4[rows[k]] -= delta * values[k];
            }
        }
    }

    /**
     * Performs one sweep of cyclic coordinate descent for sparse L1-penalized classification.
     *
     * <p>Feature {@code j} is given by row indices {@code iArr[j]} and values {@code dArr[j]}.
     * Features whose denominator {@code dArr2[j]} is not larger than 1e-8 in magnitude are skipped.
     * Otherwise the candidate weight {@code w[j] + <pseudoResidual, x_j> / dArr2[j]} is
     * soft-thresholded by {@code d / dArr2[j]}. Only the touched rows of {@code dArr5} and
     * {@code dArr6} are refreshed.
     *
     * @param iArr  per-feature row indices of the non-zero entries
     * @param dArr  per-feature non-zero values, aligned with {@code iArr}
     * @param dArr2 per-feature denominators of the update (presumably curvature terms; TODO confirm)
     * @param dArr3 values passed as the second argument of {@code OptimUtils.getPseudoResidual}
     *              (presumably the labels)
     * @param d     penalty strength before per-feature scaling
     * @param dArr4 coefficient vector, updated in place
     * @param dArr5 per-instance scores, updated in place by {@code delta * x_j}
     * @param dArr6 per-instance pseudo-residuals, recomputed in place for the touched rows
     */
    protected void doOnePass(int[][] iArr, double[][] dArr, double[] dArr2, double[] dArr3, double d, double[] dArr4, double[] dArr5, double[] dArr6) {
        for (int j = 0; j < iArr.length; j++) {
            final double theta = dArr2[j];
            // Written as a negated '>' so a NaN denominator is skipped, exactly like the positive test.
            if (!(Math.abs(theta) > 1.0E-8d)) {
                continue;
            }
            final int[] rows = iArr[j];
            final double[] values = dArr[j];
            double gradient = 0.0d;
            for (int k = 0; k < rows.length; k++) {
                gradient += dArr6[rows[k]] * values[k];
            }
            final double candidate = dArr4[j] + gradient / theta;
            final double threshold = d / theta;
            final double updated;
            if (candidate > threshold) {
                updated = candidate - threshold;
            } else if (candidate < -threshold) {
                updated = candidate + threshold;
            } else {
                updated = 0.0d;
            }
            final double delta = updated - dArr4[j];
            dArr4[j] = updated;
            for (int k = 0; k < rows.length; k++) {
                final int row = rows[k];
                dArr5[row] += delta * values[k];
                dArr6[row] = OptimUtils.getPseudoResidual(dArr5[row], dArr3[row]);
            }
        }
    }

    /**
     * Computes the largest useful penalty for dense Lasso regression.
     *
     * <p>The result is {@code max_j |<x_j, r>| / n}, where {@code n} is the length of
     * {@code dArr2}. When an intercept is fitted, the value returned by
     * {@code OptimUtils.fitIntercept(dArr2)} is added back to {@code dArr2} afterwards. This
     * presumably undoes an in-place centering; TODO confirm against OptimUtils.
     *
     * @param dArr  feature columns; {@code dArr[j]} holds the values of feature {@code j}
     * @param dArr2 residual/target vector; temporarily modified when an intercept is fitted
     * @return the maximum absolute column/residual inner product divided by {@code dArr2.length}
     */
    protected double findMaxLambda(double[][] dArr, double[] dArr2) {
        final double intercept = this.fitIntercept ? OptimUtils.fitIntercept(dArr2) : 0.0d;
        double maxAbs = 0.0d;
        for (int j = 0; j < dArr.length; j++) {
            final double candidate = Math.abs(VectorUtils.dotProduct(dArr[j], dArr2));
            // Explicit comparison (not Math.max) so a NaN product never replaces the running maximum.
            if (candidate > maxAbs) {
                maxAbs = candidate;
            }
        }
        final double result = maxAbs / dArr2.length;
        if (this.fitIntercept) {
            VectorUtils.add(dArr2, intercept);
        }
        return result;
    }

    /**
     * Computes the largest useful penalty for dense L1-penalized classification.
     *
     * <p>The result is {@code max_j |sum_i pseudoResidual(dArr3[i], dArr2[i]) * x_j[i]| / n},
     * where {@code n} is the length of {@code dArr2}. When an intercept is fitted,
     * {@code OptimUtils.fitIntercept(dArr3, dArr4, dArr2)} runs first. Afterwards
     * {@code dArr3} is reset to zeros and {@code dArr4} is recomputed via
     * {@code OptimUtils.computePseudoResidual}.
     *
     * @param dArr  feature columns; {@code dArr[j]} holds the values of feature {@code j}
     * @param dArr2 per-instance targets (presumably labels)
     * @param dArr3 per-instance scores; zero-filled on return when an intercept is fitted
     * @param dArr4 per-instance pseudo-residuals; recomputed on return when an intercept is fitted
     * @return the maximum absolute gradient component divided by {@code dArr2.length}
     */
    protected double findMaxLambda(double[][] dArr, double[] dArr2, double[] dArr3, double[] dArr4) {
        if (this.fitIntercept) {
            OptimUtils.fitIntercept(dArr3, dArr4, dArr2);
        }
        double maxAbs = 0.0d;
        for (int j = 0; j < dArr.length; j++) {
            final double[] column = dArr[j];
            double gradient = 0.0d;
            for (int i = 0; i < column.length; i++) {
                gradient += OptimUtils.getPseudoResidual(dArr3[i], dArr2[i]) * column[i];
            }
            final double magnitude = Math.abs(gradient);
            if (magnitude > maxAbs) {
                maxAbs = magnitude;
            }
        }
        final double result = maxAbs / dArr2.length;
        if (this.fitIntercept) {
            // Restore the scores/pseudo-residuals to the no-intercept starting point.
            Arrays.fill(dArr3, 0.0d);
            OptimUtils.computePseudoResidual(dArr3, dArr2, dArr4);
        }
        return result;
    }

    /**
     * Computes the largest useful penalty for sparse Lasso regression.
     *
     * <p>Each feature {@code j} is wrapped as a {@code SparseVector(iArr[j], dArr[j])}. The
     * result is the largest absolute inner product with the residuals divided by
     * {@code dArr2.length}. When an intercept is fitted, the value returned by
     * {@code OptimUtils.fitIntercept(dArr2)} is added back to {@code dArr2} before returning.
     *
     * @param iArr  per-feature indices of the non-zero entries
     * @param dArr  per-feature non-zero values, aligned with {@code iArr}
     * @param dArr2 residual/target vector; temporarily modified when an intercept is fitted
     * @return the maximum absolute column/residual inner product divided by {@code dArr2.length}
     */
    protected double findMaxLambda(int[][] iArr, double[][] dArr, double[] dArr2) {
        final double intercept = this.fitIntercept ? OptimUtils.fitIntercept(dArr2) : 0.0d;
        // Built after the intercept step so it sees the adjusted residuals.
        final DenseVector residuals = new DenseVector(dArr2);
        double maxAbs = 0.0d;
        int j = 0;
        while (j < iArr.length) {
            final SparseVector column = new SparseVector(iArr[j], dArr[j]);
            final double candidate = Math.abs(VectorUtils.dotProduct(column, residuals));
            if (candidate > maxAbs) {
                maxAbs = candidate;
            }
            j++;
        }
        final double result = maxAbs / dArr2.length;
        if (this.fitIntercept) {
            VectorUtils.add(dArr2, intercept);
        }
        return result;
    }

    /**
     * Computes the largest useful penalty for sparse L1-penalized classification.
     *
     * <p>For each feature {@code j} (row indices {@code iArr[j]}, values {@code dArr[j]}),
     * the method accumulates {@code pseudoResidual(dArr3[row], dArr2[row]) * value} over the
     * non-zero entries. It returns the largest absolute sum divided by {@code dArr2.length}.
     * When an intercept is fitted, {@code dArr3} is zero-filled and {@code dArr4} is recomputed
     * before returning.
     *
     * @param iArr  per-feature row indices of the non-zero entries
     * @param dArr  per-feature non-zero values, aligned with {@code iArr}
     * @param dArr2 per-instance targets (presumably labels)
     * @param dArr3 per-instance scores; zero-filled on return when an intercept is fitted
     * @param dArr4 per-instance pseudo-residuals; recomputed on return when an intercept is fitted
     * @return the maximum absolute gradient component divided by {@code dArr2.length}
     */
    protected double findMaxLambda(int[][] iArr, double[][] dArr, double[] dArr2, double[] dArr3, double[] dArr4) {
        if (this.fitIntercept) {
            OptimUtils.fitIntercept(dArr3, dArr4, dArr2);
        }
        double maxAbs = 0.0d;
        for (int j = 0; j < dArr.length; j++) {
            final int[] rows = iArr[j];
            final double[] values = dArr[j];
            double gradient = 0.0d;
            for (int k = 0; k < rows.length; k++) {
                final int row = rows[k];
                gradient += OptimUtils.getPseudoResidual(dArr3[row], dArr2[row]) * values[k];
            }
            final double magnitude = Math.abs(gradient);
            if (magnitude > maxAbs) {
                maxAbs = magnitude;
            }
        }
        final double result = maxAbs / dArr2.length;
        if (this.fitIntercept) {
            // Restore the scores/pseudo-residuals to the no-intercept starting point.
            Arrays.fill(dArr3, 0.0d);
            OptimUtils.computePseudoResidual(dArr3, dArr2, dArr4);
        }
        return result;
    }

    /**
     * Returns whether an intercept term is fitted.
     *
     * @return {@code true} if an intercept is fitted
     */
    public boolean fitIntercept() {
        return fitIntercept;
    }

    /**
     * Sets whether an intercept term is fitted.
     *
     * @param z {@code true} to fit an intercept
     */
    public void fitIntercept(boolean z) {
        fitIntercept = z;
    }

    /**
     * Returns the convergence tolerance.
     *
     * @return the epsilon used for convergence checks
     */
    public double getEpsilon() {
        return epsilon;
    }

    /**
     * Returns the configured L1 penalty.
     *
     * @return the lambda value
     */
    public double getLambda() {
        return lambda;
    }

    /**
     * Returns the configured maximum number of iterations.
     *
     * @return the iteration limit
     */
    public int getMaxNumIters() {
        return maxNumIters;
    }

    /**
     * Returns the configured number of penalty values.
     *
     * @return the number of lambdas
     */
    public int getNumLambdas() {
        return numLambdas;
    }

    /**
     * Returns the learning task (regression or classification).
     *
     * @return the configured task
     */
    public Learner.Task getTask() {
        return task;
    }

    /**
     * Returns whether progress is printed during training.
     *
     * @return {@code true} if verbose output is enabled
     */
    public boolean isVerbose() {
        return verbose;
    }

    /**
     * Returns whether the selected features are refitted after the Lasso path.
     *
     * @return {@code true} if refitting is enabled
     */
    public boolean refit() {
        return refit;
    }

    /**
     * Sets whether the selected features are refitted after the Lasso path.
     *
     * @param z {@code true} to enable refitting
     */
    public void refit(boolean z) {
        refit = z;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v7, types: [double[], double[][]] */
    protected GLM refitRegressor(int[] iArr, boolean[] zArr, double[][] dArr, double[] dArr2, int i) {
        ArrayList arrayList = new ArrayList();
        for (int i2 = 0; i2 < iArr.length; i2++) {
            if (zArr[i2]) {
                arrayList.add(dArr[i2]);
            }
        }
        if (arrayList.size() == 0) {
            GLM glm = new GLM(0);
            if (this.fitIntercept) {
                glm.intercept[0] = StatUtils.mean(dArr2);
            }
            return glm;
        }
        ?? r0 = new double[arrayList.size()];
        for (int i3 = 0; i3 < r0.length; i3++) {
            r0[i3] = (double[]) arrayList.get(i3);
        }
        int[] iArr2 = new int[r0.length];
        for (int i4 = 0; i4 < iArr2.length; i4++) {
            iArr2[i4] = i4;
        }
        RidgeLearner ridgeLearner = new RidgeLearner();
        ridgeLearner.setVerbose(this.verbose);
        ridgeLearner.setEpsilon(this.epsilon);
        ridgeLearner.fitIntercept(this.fitIntercept);
        GLM buildRegressor = ridgeLearner.buildRegressor(iArr2, r0, dArr2, i, 1.0E-8d);
        double[] dArr3 = new double[iArr.length];
        double[] coefficients = buildRegressor.coefficients(0);
        int i5 = 0;
        for (int i6 = 0; i6 < dArr3.length; i6++) {
            if (zArr[i6]) {
                int i7 = i5;
                i5++;
                dArr3[i6] = coefficients[i7];
            }
        }
        return GLMOptimUtils.getGLM(iArr, dArr3, buildRegressor.intercept(0));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v13, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v8, types: [int[], int[][]] */
    protected GLM refitRegressor(int[] iArr, boolean[] zArr, int[][] iArr2, double[][] dArr, double[] dArr2, int i) {
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i2 = 0; i2 < iArr.length; i2++) {
            if (zArr[i2]) {
                arrayList.add(iArr2[i2]);
                arrayList2.add(dArr[i2]);
            }
        }
        if (arrayList.size() == 0) {
            GLM glm = new GLM(0);
            if (this.fitIntercept) {
                glm.intercept[0] = StatUtils.mean(dArr2);
            }
            return glm;
        }
        ?? r0 = new int[arrayList.size()];
        for (int i3 = 0; i3 < r0.length; i3++) {
            r0[i3] = (int[]) arrayList.get(i3);
        }
        ?? r02 = new double[arrayList2.size()];
        for (int i4 = 0; i4 < r0.length; i4++) {
            r02[i4] = (double[]) arrayList2.get(i4);
        }
        int[] iArr3 = new int[r0.length];
        for (int i5 = 0; i5 < iArr3.length; i5++) {
            iArr3[i5] = i5;
        }
        RidgeLearner ridgeLearner = new RidgeLearner();
        ridgeLearner.setVerbose(this.verbose);
        ridgeLearner.setEpsilon(this.epsilon);
        ridgeLearner.fitIntercept(this.fitIntercept);
        GLM buildRegressor = ridgeLearner.buildRegressor(iArr3, r0, r02, dArr2, i, 1.0E-8d);
        double[] dArr3 = new double[iArr.length];
        double[] coefficients = buildRegressor.coefficients(0);
        int i6 = 0;
        for (int i7 = 0; i7 < dArr3.length; i7++) {
            if (zArr[i7]) {
                int i8 = i6;
                i6++;
                dArr3[i7] = coefficients[i8];
            }
        }
        return GLMOptimUtils.getGLM(iArr, dArr3, buildRegressor.intercept(0));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v5, types: [double[], double[][]] */
    protected GLM refitClassifier(int[] iArr, boolean[] zArr, double[][] dArr, double[] dArr2, int i) {
        ArrayList arrayList = new ArrayList();
        for (int i2 = 0; i2 < iArr.length; i2++) {
            if (zArr[i2]) {
                arrayList.add(dArr[i2]);
            }
        }
        ?? r0 = new double[arrayList.size()];
        for (int i3 = 0; i3 < r0.length; i3++) {
            r0[i3] = (double[]) arrayList.get(i3);
        }
        int[] iArr2 = new int[r0.length];
        for (int i4 = 0; i4 < iArr2.length; i4++) {
            iArr2[i4] = i4;
        }
        RidgeLearner ridgeLearner = new RidgeLearner();
        ridgeLearner.setVerbose(this.verbose);
        ridgeLearner.setEpsilon(this.epsilon);
        ridgeLearner.fitIntercept(this.fitIntercept);
        GLM buildBinaryClassifier = ridgeLearner.buildBinaryClassifier(iArr2, r0, dArr2, i, 1.0E-8d);
        double[] dArr3 = new double[iArr.length];
        double[] coefficients = buildBinaryClassifier.coefficients(0);
        int i5 = 0;
        for (int i6 = 0; i6 < dArr3.length; i6++) {
            if (zArr[i6]) {
                int i7 = i5;
                i5++;
                dArr3[i6] = coefficients[i7];
            }
        }
        return GLMOptimUtils.getGLM(iArr, dArr3, buildBinaryClassifier.intercept(0));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v11, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v6, types: [int[], int[][]] */
    protected GLM refitClassifier(int[] iArr, boolean[] zArr, int[][] iArr2, double[][] dArr, double[] dArr2, int i) {
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i2 = 0; i2 < iArr.length; i2++) {
            if (zArr[i2]) {
                arrayList.add(iArr2[i2]);
                arrayList2.add(dArr[i2]);
            }
        }
        ?? r0 = new int[arrayList.size()];
        for (int i3 = 0; i3 < r0.length; i3++) {
            r0[i3] = (int[]) arrayList.get(i3);
        }
        ?? r02 = new double[arrayList2.size()];
        for (int i4 = 0; i4 < r0.length; i4++) {
            r02[i4] = (double[]) arrayList2.get(i4);
        }
        int[] iArr3 = new int[r0.length];
        for (int i5 = 0; i5 < iArr3.length; i5++) {
            iArr3[i5] = i5;
        }
        RidgeLearner ridgeLearner = new RidgeLearner();
        ridgeLearner.setVerbose(this.verbose);
        ridgeLearner.setEpsilon(this.epsilon);
        ridgeLearner.fitIntercept(this.fitIntercept);
        GLM buildBinaryClassifier = ridgeLearner.buildBinaryClassifier(iArr3, r0, r02, dArr2, i, 1.0E-8d);
        double[] dArr3 = new double[iArr.length];
        double[] coefficients = buildBinaryClassifier.coefficients(0);
        int i6 = 0;
        for (int i7 = 0; i7 < dArr3.length; i7++) {
            if (zArr[i7]) {
                int i8 = i6;
                i6++;
                dArr3[i7] = coefficients[i8];
            }
        }
        return GLMOptimUtils.getGLM(iArr, dArr3, buildBinaryClassifier.intercept(0));
    }

    /**
     * Sets the convergence tolerance.
     *
     * @param d the new epsilon
     */
    public void setEpsilon(double d) {
        epsilon = d;
    }

    /**
     * Sets the L1 penalty.
     *
     * @param d the new lambda
     */
    public void setLambda(double d) {
        lambda = d;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param i the new iteration limit
     */
    public void setMaxNumIters(int i) {
        maxNumIters = i;
    }

    /**
     * Sets the number of penalty values.
     *
     * @param i the new number of lambdas
     */
    public void setNumLambdas(int i) {
        numLambdas = i;
    }

    /**
     * Sets the learning task.
     *
     * @param task the task to learn (regression or classification)
     */
    public void setTask(Learner.Task task) {
        this.task = task;
    }

    /**
     * Enables or disables progress output during training.
     *
     * @param z {@code true} to print progress
     */
    public void setVerbose(boolean z) {
        verbose = z;
    }

    /**
     * Lazily builds and caches a lookup table from {@code Learner.Task} ordinals to small case
     * numbers: CLASSIFICATION maps to 1, REGRESSION to 2, anything else to 0.
     *
     * <p>Presumably this backs a {@code switch} over {@code Learner.Task} elsewhere in this class
     * (compiler-synthesized). A missing enum constant ({@code NoSuchFieldError}) is tolerated so
     * the table can still be built against a different {@code Learner.Task} binary.
     *
     * @return the cached ordinal-to-case table
     */
    static /* synthetic */ int[] $SWITCH_TABLE$mltk$predictor$Learner$Task() {
        int[] table = $SWITCH_TABLE$mltk$predictor$Learner$Task;
        if (table == null) {
            table = new int[Learner.Task.valuesCustom().length];
            try {
                table[Learner.Task.CLASSIFICATION.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Constant absent in this Learner.Task version; leave its slot at 0.
            }
            try {
                table[Learner.Task.REGRESSION.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // Constant absent in this Learner.Task version; leave its slot at 0.
            }
            $SWITCH_TABLE$mltk$predictor$Learner$Task = table;
        }
        return table;
    }
}
