package mltk.predictor.tree.ensemble.ag;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.io.InstancesReader;
import mltk.predictor.Bagging;
import mltk.predictor.Learner;
import mltk.predictor.evaluation.AUC;
import mltk.predictor.evaluation.Metric;
import mltk.predictor.evaluation.RMSE;
import mltk.predictor.io.PredictorWriter;
import mltk.predictor.tree.RegressionTree;
import mltk.predictor.tree.RegressionTreeLearner;
import mltk.util.Random;
import mltk.util.tuple.IntPair;
import weka.core.TestInstances;

/**
 * Learner for Additive Groves (AG) ensembles of regression trees.
 *
 * <p>{@link #buildRegressor(Instances, Instances)} trains a grid of bagged models indexed by the
 * number of trees per grove (row index + 1) and the tree-size parameter alpha (column index into
 * the sequence produced by {@link #getAlpha(int)}). Each cell is evaluated on a validation set with
 * {@link #getMetric()}. The grid is grown (more alphas, more trees, more bagging iterations) while
 * the best cell lies on the grid boundary or bagging has not converged.
 */
public class AdditiveGrovesLearner extends Learner {
    private int bestNumTrees;
    private int bestBaggingIters;
    private double bestAlpha;
    private boolean verbose = false;
    private int numTrees = 6;
    private int baggingIters = 60;
    // NOTE(review): stored via setMinAlpha / the "-a" option but not read anywhere in this class;
    // buildRegressor stops refining alpha at 1 / (training set size) instead. Confirm intent.
    private double minAlpha = 0.01d;
    private Metric metric = new RMSE();

    /** Grid of trained models: {@code groves[treeIdx][alphaIdx]} holds one grove per bagging iteration. */
    public class ModelMatrix {
        AdditiveGroves[][] groves;

        ModelMatrix(int numTreeRows, int numAlphas) {
            this.groves = new AdditiveGroves[numTreeRows][numAlphas];
        }

        /**
         * Grows the grid to {@code numTreeRows x numAlphas}, keeping existing entries.
         *
         * @param baggingIters unused; kept so the signature parallels the other matrices
         */
        void expand(int numTreeRows, int numAlphas, int baggingIters) {
            AdditiveGroves[][] grown = new AdditiveGroves[numTreeRows][numAlphas];
            for (int t = 0; t < this.groves.length; t++) {
                System.arraycopy(this.groves[t], 0, grown[t], 0, this.groves[t].length);
            }
            this.groves = grown;
        }

        /** Appends one bagged grove to cell ({@code treeIdx}, {@code alphaIdx}), creating the cell if needed. */
        void add(int treeIdx, int alphaIdx, RegressionTree[] grove) {
            if (this.groves[treeIdx][alphaIdx] == null) {
                this.groves[treeIdx][alphaIdx] = new AdditiveGroves();
            }
            this.groves[treeIdx][alphaIdx].groves.add(grove);
        }
    }

    /** Command-line options for {@link #main(String[])}. */
    static class Options {

        @Argument(name = "-r", description = "attribute file path")
        String attPath = null;

        @Argument(name = "-t", description = "train set path", required = true)
        String trainPath = null;

        @Argument(name = "-v", description = "valid set path", required = true)
        String validPath = null;

        @Argument(name = "-o", description = "output model path")
        String outputModelPath = null;

        @Argument(name = "-e", description = "AUC (a), RMSE (r) (default: r)")
        String metric = null;

        @Argument(name = "-b", description = "bagging iterations (default: 60)")
        int baggingIters = 60;

        @Argument(name = "-n", description = "number of trees in a grove (default: 6)")
        int n = 6;

        @Argument(name = "-a", description = "minimum alpha (default: 0.01)")
        double a = 0.01d;

        @Argument(name = "-s", description = "seed of the random number generator (default: 0)")
        long seed = 0;

        Options() {
        }
    }

    /**
     * Validation performance per cell and bagging iteration:
     * {@code perf[treeIdx][alphaIdx][iter]} is the metric of the ensemble averaged over iterations
     * {@code 0..iter}.
     */
    public class PerformanceMatrix {
        Metric metric;
        double[][][] perf;

        PerformanceMatrix(int numTreeRows, int numAlphas, int baggingIters, Metric metric) {
            this.perf = new double[numTreeRows][numAlphas][baggingIters];
            this.metric = metric;
        }

        /** Grows the matrix to the given dimensions, keeping existing values. */
        void expand(int numTreeRows, int numAlphas, int baggingIters) {
            double[][][] grown = new double[numTreeRows][numAlphas][baggingIters];
            for (int t = 0; t < this.perf.length; t++) {
                for (int a = 0; a < this.perf[t].length; a++) {
                    double[] history = this.perf[t][a];
                    System.arraycopy(history, 0, grown[t][a], 0, history.length);
                }
            }
            this.perf = grown;
        }

        /** Records the metric of {@code predictions} against {@code targets} for one cell and iteration. */
        void eval(int treeIdx, int alphaIdx, int iter, double[] predictions, double[] targets) {
            this.perf[treeIdx][alphaIdx][iter] = this.metric.eval(predictions, targets);
        }

        /**
         * Finds the cell whose performance at the last recorded bagging iteration is best.
         *
         * <p>The search always runs; {@code verbose} only controls whether the matrix is printed.
         *
         * @return (treeIdx, alphaIdx) of the best cell, or (0, 0) if no cell beats the metric's worst value
         */
        IntPair getBestParameters() {
            int bestTreeIdx = 0;
            int bestAlphaIdx = 0;
            double bestPerf = this.metric.worstValue();
            boolean print = AdditiveGrovesLearner.this.verbose;
            if (print) {
                System.out.println("Perf Matrix:");
            }
            for (int t = 0; t < this.perf.length; t++) {
                for (int a = 0; a < this.perf[t].length; a++) {
                    double[] history = this.perf[t][a];
                    if (history.length == 0) {
                        continue; // no bagging iterations recorded for this cell
                    }
                    double p = history[history.length - 1];
                    if (print) {
                        System.out.print(p + " ");
                    }
                    if (this.metric.isFirstBetter(p, bestPerf)) {
                        bestTreeIdx = t;
                        bestAlphaIdx = a;
                        bestPerf = p;
                    }
                }
                if (print) {
                    System.out.println();
                }
            }
            if (print) {
                System.out.println("Best perf on validation set = " + bestPerf);
            }
            return new IntPair(bestTreeIdx, bestAlphaIdx);
        }

        /** Delegates to {@link Bagging#analyzeBagging} on the iteration history of one cell. */
        boolean analyzeBagging(int treeIdx, int alphaIdx) {
            return Bagging.analyzeBagging(this.perf[treeIdx][alphaIdx], this.metric);
        }
    }

    /**
     * Running sums of validation predictions: {@code sumPrediction[treeIdx][alphaIdx][k]} is the sum
     * over completed bagging iterations of the grove's prediction for validation instance {@code k}.
     */
    public class PredictionMatrix {
        double[][][] sumPrediction;
        int n;

        PredictionMatrix(int numTreeRows, int numAlphas, int numValid) {
            this.sumPrediction = new double[numTreeRows][numAlphas][numValid];
            this.n = numValid;
        }

        /** Grows the grid dimensions, keeping existing sums. */
        void expand(int numTreeRows, int numAlphas) {
            double[][][] grown = new double[numTreeRows][numAlphas][this.n];
            for (int t = 0; t < this.sumPrediction.length; t++) {
                for (int a = 0; a < this.sumPrediction[t].length; a++) {
                    System.arraycopy(this.sumPrediction[t][a], 0, grown[t][a], 0, this.n);
                }
            }
            this.sumPrediction = grown;
        }
    }

    /** Returns the validation metric used for model selection. */
    public Metric getMetric() {
        return this.metric;
    }

    /**
     * Sets the validation metric. Only {@link RMSE} and {@link AUC} are accepted; any other value
     * (including {@code null}) is ignored and the current metric is kept.
     */
    public void setMetric(Metric metric) {
        if ((metric instanceof RMSE) || (metric instanceof AUC)) {
            this.metric = metric;
        }
    }

    /** Returns the number of trees per grove selected by the last {@link #buildRegressor} call. */
    public int getBestNumTrees() {
        return this.bestNumTrees;
    }

    /** Returns the number of bagging iterations used by the last {@link #buildRegressor} call. */
    public int getBestBaggingIters() {
        return this.bestBaggingIters;
    }

    /** Returns the alpha selected by the last {@link #buildRegressor} call. */
    public double getBestAlpha() {
        return this.bestAlpha;
    }

    /** Returns the initial number of grove sizes (rows) in the search grid. */
    public int getNumTrees() {
        return this.numTrees;
    }

    /** Sets the initial number of grove sizes (rows) in the search grid. */
    public void setNumTrees(int numTrees) {
        this.numTrees = numTrees;
    }

    /** Returns the stored minimum alpha (currently not consulted by the training code). */
    public double getMinAlpha() {
        return this.minAlpha;
    }

    /** Stores a minimum alpha (currently not consulted by the training code). */
    public void setMinAlpha(double minAlpha) {
        this.minAlpha = minAlpha;
    }

    /** Returns whether progress is printed to standard output. */
    public boolean isVerbose() {
        return this.verbose;
    }

    /** Enables or disables progress output to standard output. */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Returns the initial number of bagging iterations. */
    public int getBaggingIters() {
        return this.baggingIters;
    }

    /** Sets the initial number of bagging iterations. */
    public void setBaggingIters(int baggingIters) {
        this.baggingIters = baggingIters;
    }

    /**
     * Trains the AG grid on {@code trainSet}, selects the best (grove size, alpha) cell on
     * {@code validSet}, and returns that cell's bagged groves.
     *
     * <p>The grid starts at {@link #getNumTrees()} rows x 6 alphas x {@link #getBaggingIters()}
     * iterations. Each round expands it by 3 alphas when the best alpha is the smallest one tried
     * (and still above 1 / |trainSet|). It expands by 3 rows when the best grove size is the largest
     * one tried. It adds 40 bagging iterations when {@link Bagging#analyzeBagging} reports that
     * bagging has not converged. Training targets are overwritten during training and restored
     * before returning. The selected parameters are available via the {@code getBest*} accessors.
     *
     * @param trainSet training instances
     * @param validSet validation instances used for model selection
     * @return the groves of the best cell
     */
    public AdditiveGroves buildRegressor(Instances trainSet, Instances validSet) {
        int iters = this.baggingIters;
        int numTreeRows = this.numTrees;
        int numAlphas = 6;
        List<Double> alphas = new ArrayList<>();
        for (int a = 0; a < numAlphas; a++) {
            alphas.add(getAlpha(a));
        }
        double[] target = new double[trainSet.size()];
        for (int k = 0; k < target.length; k++) {
            target[k] = trainSet.get(k).getTarget();
        }
        double[] validTarget = new double[validSet.size()];
        for (int k = 0; k < validTarget.length; k++) {
            validTarget[k] = validSet.get(k).getTarget();
        }
        PerformanceMatrix perfMatrix = new PerformanceMatrix(numTreeRows, numAlphas, iters, this.metric);
        ModelMatrix modelMatrix = new ModelMatrix(numTreeRows, numAlphas);
        PredictionMatrix predMatrix = new PredictionMatrix(numTreeRows, numAlphas, validSet.size());

        int doneIters = 0;
        IntPair best;
        boolean converged;
        do {
            // Run any bagging iterations not yet trained for the current grid.
            runLayeredTraining(trainSet, validSet, doneIters, iters, 0, numTreeRows, 0, numAlphas, alphas,
                    perfMatrix, modelMatrix, predMatrix, target, validTarget);
            best = perfMatrix.getBestParameters();
            converged = true;
            doneIters = iters;
            int prevTreeRows = numTreeRows;
            int prevAlphas = numAlphas;
            if (best.v2 == numAlphas - 1 && alphas.get(numAlphas - 1) > 1.0d / trainSet.size()) {
                // Best alpha is on the small-alpha edge: try three smaller alphas.
                converged = false;
                numAlphas += 3;
                for (int a = prevAlphas; a < numAlphas; a++) {
                    alphas.add(getAlpha(a));
                }
                if (this.verbose) {
                    System.out.println("Expanding alpha grid to " + numAlphas + " values");
                }
                predMatrix.expand(numTreeRows, numAlphas);
                perfMatrix.expand(numTreeRows, numAlphas, iters);
                modelMatrix.expand(numTreeRows, numAlphas, iters);
                runLayeredTraining(trainSet, validSet, 0, iters, 0, numTreeRows, prevAlphas, numAlphas, alphas,
                        perfMatrix, modelMatrix, predMatrix, target, validTarget);
            }
            if (best.v1 == numTreeRows - 1) {
                // Best grove size is the largest tried: add three larger grove sizes.
                converged = false;
                numTreeRows += 3;
                predMatrix.expand(numTreeRows, numAlphas);
                perfMatrix.expand(numTreeRows, numAlphas, iters);
                modelMatrix.expand(numTreeRows, numAlphas, iters);
                runLayeredTraining(trainSet, validSet, 0, iters, prevTreeRows, numTreeRows, 0, numAlphas, alphas,
                        perfMatrix, modelMatrix, predMatrix, target, validTarget);
            }
            if (!perfMatrix.analyzeBagging(best.v1, best.v2)) {
                // Bagging has not converged: the next loop pass trains 40 more iterations.
                converged = false;
                iters += 40;
                predMatrix.expand(numTreeRows, numAlphas);
                perfMatrix.expand(numTreeRows, numAlphas, iters);
                modelMatrix.expand(numTreeRows, numAlphas, iters);
            }
        } while (!converged);

        for (int k = 0; k < target.length; k++) {
            trainSet.get(k).setTarget(target[k]);
        }
        if (this.verbose) {
            System.out.println("Best model:");
            System.out.println("Alpha = " + alphas.get(best.v2));
            System.out.println("N = " + (best.v1 + 1));
            System.out.println("b = " + iters);
        }
        this.bestBaggingIters = iters;
        this.bestNumTrees = best.v1 + 1;
        this.bestAlpha = getAlpha(best.v2);
        return modelMatrix.groves[best.v1][best.v2];
    }

    /**
     * Trains a fixed-parameter AG model with layered training: in each bagging iteration a grove of
     * {@code numTrees} trees is backfitted at alphas {@code getAlpha(0..k)}, each level starting from
     * the previous one. Here k is {@link #getAlphaIdx(double, int)}. Only the final grove of each
     * iteration becomes part of the returned model.
     *
     * @param trainSet training instances (targets are overwritten during training and restored)
     * @param baggingIters number of bagging iterations (groves in the result)
     * @param numTrees number of trees per grove
     * @param alpha target alpha
     * @return the bagged groves
     */
    public AdditiveGroves runLayeredTraining(Instances trainSet, int baggingIters, int numTrees, double alpha) {
        int n = trainSet.size();
        double[] target = new double[n];
        for (int k = 0; k < n; k++) {
            target[k] = trainSet.get(k).getTarget();
        }
        int numAlphaSteps = getAlphaIdx(alpha, n) + 1;
        AdditiveGroves ag = new AdditiveGroves();
        for (int iter = 0; iter < baggingIters; iter++) {
            double[][] predictions = new double[numTrees][n];
            double[] residual = new double[n];
            System.arraycopy(target, 0, residual, 0, n);
            if (this.verbose) {
                System.out.println("Iteration " + (iter + 1) + " out of " + baggingIters);
            }
            RegressionTree[] grove = null;
            for (int a = 0; a < numAlphaSteps; a++) {
                double stepAlpha = getAlpha(a);
                if (this.verbose) {
                    System.out.println("\tBuilding models with alpha = " + stepAlpha);
                }
                grove = new RegressionTree[numTrees];
                backfit(trainSet, stepAlpha, grove, predictions, residual);
            }
            // numAlphaSteps >= 1, so grove holds the model trained at the target alpha.
            ag.groves.add(grove);
        }
        for (int k = 0; k < n; k++) {
            trainSet.get(k).setTarget(target[k]);
        }
        return ag;
    }

    /**
     * Builds a model on {@code instances} by using the first 4/5 (by index, without shuffling) for
     * training and the remainder for validation.
     */
    @Override // mltk.predictor.Learner
    public AdditiveGroves build(Instances instances) {
        Instances trainSet = new Instances(instances.getAttributes(), instances.getTargetAttribute());
        Instances validSet = new Instances(instances.getAttributes(), instances.getTargetAttribute());
        int trainSize = (instances.size() / 5) * 4;
        for (int k = 0; k < trainSize; k++) {
            trainSet.add(instances.get(k));
        }
        for (int k = trainSize; k < instances.size(); k++) {
            validSet.add(instances.get(k));
        }
        return buildRegressor(trainSet, validSet);
    }

    /**
     * Returns the {@code i}-th alpha of the decreasing grid 0.05, 0.02, 0.01, 0.005, ... computed as
     * mantissa (5, 2, 1 by {@code i % 3}) divided by 10 {@code i / 3 + 1} times.
     */
    protected double getAlpha(int i) {
        double alpha;
        switch (i % 3) {
            case 0:
                alpha = 5.0d;
                break;
            case 1:
                alpha = 2.0d;
                break;
            default:
                alpha = 1.0d;
                break;
        }
        // Repeated division (rather than Math.pow) keeps the exact values used elsewhere.
        for (int k = 0; k <= i / 3; k++) {
            alpha /= 10.0d;
        }
        return alpha;
    }

    /**
     * Returns the index of the first grid alpha that is at most {@code alpha} or at most
     * {@code 1 / n}, whichever is reached first.
     *
     * @param alpha requested alpha
     * @param n training set size
     */
    protected int getAlphaIdx(double alpha, int n) {
        double floor = 1.0d / n;
        int idx = 0;
        while (getAlpha(idx) > alpha && getAlpha(idx) > floor) {
            idx++;
        }
        return idx;
    }

    /**
     * Backfits a grove on a bootstrap sample of {@code instances} until the out-of-bag RMSE of the
     * residuals improves by at most 0.2% relatively (or reaches 0).
     *
     * <p>Each pass refits every tree slot in turn: the slot's old predictions are added back to
     * {@code residuals} and the targets of {@code instances} are set to the result. A new tree is
     * built on the bootstrap copies, and its predictions are subtracted again.
     *
     * @param instances training instances; their targets are overwritten with residuals
     * @param alpha alpha passed to the ALPHA_LIMITED tree learner
     * @param grove tree slots to (re)fill; its length is the grove size
     * @param predictions per-slot training predictions, updated in place
     * @param residuals training residuals, updated in place
     */
    protected void backfit(Instances instances, double alpha, RegressionTree[] grove, double[][] predictions, double[] residuals) {
        HashMap<Integer, Integer> bag = new HashMap<>();
        ArrayList<Integer> oob = new ArrayList<>();
        Bagging.createBootstrapSample(instances, bag, oob);
        Instances bagSet = new Instances(instances.getAttributes(), instances.getTargetAttribute(), bag.size());
        for (Integer idx : bag.keySet()) {
            // NOTE(review): trees are fit on these copies while targets are later set on the
            // originals; this assumes the copies observe those target updates. Confirm against Instance.
            Instance copy = instances.get(idx.intValue()).m37clone();
            copy.setWeight(bag.get(idx).intValue());
            bagSet.add(copy);
        }
        RegressionTreeLearner treeLearner = new RegressionTreeLearner();
        treeLearner.setConstructionMode(RegressionTreeLearner.Mode.ALPHA_LIMITED);
        treeLearner.setAlpha(alpha);
        double prevRmse = evalRMSE(oob, residuals);
        while (true) {
            for (int k = 0; k < grove.length; k++) {
                // Visit slots starting from the last one: n-1, 0, 1, ..., n-2.
                int slot = (k + grove.length - 1) % grove.length;
                double[] slotPred = predictions[slot];
                for (int j = 0; j < residuals.length; j++) {
                    residuals[j] += slotPred[j];
                    instances.get(j).setTarget(residuals[j]);
                }
                RegressionTree tree = treeLearner.build(bagSet);
                grove[slot] = tree;
                for (int j = 0; j < residuals.length; j++) {
                    double p = tree.regress(instances.get(j));
                    slotPred[j] = p;
                    residuals[j] -= p;
                }
            }
            double rmse = evalRMSE(oob, residuals);
            // Negated '>' so a NaN relative improvement also terminates instead of looping forever.
            if (rmse == 0.0d || !((prevRmse - rmse) / prevRmse > 0.002d)) {
                return;
            }
            prevRmse = rmse;
        }
    }

    /** Returns the sum of the predictions of all trees in {@code grove} for {@code instance}. */
    protected double regress(RegressionTree[] grove, Instance instance) {
        double sum = 0.0d;
        for (RegressionTree tree : grove) {
            sum += tree.regress(instance);
        }
        return sum;
    }

    /**
     * Trains grid cells for bagging iterations {@code [startIter, endIter)}, grove sizes
     * {@code [startTreeIdx + 1, endTreeIdx]} (rows {@code [startTreeIdx, endTreeIdx)}) and alphas
     * {@code alphas[startAlphaIdx, endAlphaIdx)}. Each cell's grove is stored in
     * {@code modelMatrix}, and the averaged validation predictions are evaluated into
     * {@code perfMatrix}.
     *
     * <p>Within an iteration, each row is trained layer by layer over the alpha range, each layer
     * starting from the previous layer's grove. When {@code startAlphaIdx > 0}, that starting state
     * is reconstructed from the groves already stored at {@code startAlphaIdx - 1}.
     *
     * @param target original training targets
     * @param validTarget validation targets
     */
    protected void runLayeredTraining(Instances trainSet, Instances validSet, int startIter, int endIter,
            int startTreeIdx, int endTreeIdx, int startAlphaIdx, int endAlphaIdx, List<Double> alphas,
            PerformanceMatrix perfMatrix, ModelMatrix modelMatrix, PredictionMatrix predMatrix,
            double[] target, double[] validTarget) {
        int n = trainSet.size();
        int numRows = endTreeIdx - startTreeIdx;
        double[] avgValidPred = new double[validSet.size()];
        for (int iter = startIter; iter < endIter; iter++) {
            double[][][] predictions = new double[numRows][endTreeIdx][n];
            double[][] residuals = new double[numRows][n];
            for (double[] residual : residuals) {
                System.arraycopy(target, 0, residual, 0, n);
            }
            if (startAlphaIdx != 0) {
                // Resume each row from the grove trained at the previous alpha in this iteration.
                for (int t = startTreeIdx; t < endTreeIdx; t++) {
                    update(trainSet, modelMatrix.groves[t][startAlphaIdx - 1].groves.get(iter),
                            predictions, residuals, t - startTreeIdx);
                }
            }
            if (this.verbose) {
                System.out.println("Iteration " + (iter + 1) + " out of " + endIter);
            }
            for (int a = startAlphaIdx; a < endAlphaIdx; a++) {
                double alpha = alphas.get(a);
                if (this.verbose) {
                    System.out.println("\tBuilding models with alpha = " + alpha);
                }
                for (int t = startTreeIdx; t < endTreeIdx; t++) {
                    int row = t - startTreeIdx;
                    RegressionTree[] grove = new RegressionTree[t + 1];
                    backfit(trainSet, alpha, grove, predictions[row], residuals[row]);
                    modelMatrix.add(t, a, grove);
                    double[] sumPred = predMatrix.sumPrediction[t][a];
                    for (int k = 0; k < sumPred.length; k++) {
                        sumPred[k] += regress(grove, validSet.get(k));
                        avgValidPred[k] = sumPred[k] / (iter + 1);
                    }
                    perfMatrix.eval(t, a, iter, avgValidPred, validTarget);
                }
            }
        }
    }

    /**
     * Returns the root mean square of {@code residuals} over the given indices, or 0 when
     * {@code indices} is empty (e.g. a bootstrap sample with no out-of-bag instances).
     */
    protected double evalRMSE(List<Integer> indices, double[] residuals) {
        if (indices.isEmpty()) {
            return 0.0d;
        }
        double sumSq = 0.0d;
        for (int idx : indices) {
            double r = residuals[idx];
            sumSq += r * r;
        }
        return Math.sqrt(sumSq / indices.size());
    }

    /**
     * Loads every tree of {@code grove} into slot {@code slot}: stores each tree's training
     * predictions in {@code predictions[slot][tree]} and subtracts them from {@code residuals[slot]}.
     */
    protected void update(Instances instances, RegressionTree[] grove, double[][][] predictions, double[][] residuals, int slot) {
        double[][] slotPreds = predictions[slot];
        double[] residual = residuals[slot];
        for (int k = 0; k < grove.length; k++) {
            RegressionTree tree = grove[k];
            for (int j = 0; j < instances.size(); j++) {
                double p = tree.regress(instances.get(j));
                slotPreds[k][j] = p;
                residual[j] -= p;
            }
        }
    }

    /** Command-line entry point: trains on -t, selects on -v, optionally writes the model to -o. */
    public static void main(String[] args) throws Exception {
        Options opts = new Options();
        CmdLineParser parser = new CmdLineParser(AdditiveGrovesLearner.class, opts);
        Metric metric = null;
        try {
            parser.parse(args);
            // -e is optional: default to RMSE instead of calling startsWith(null).
            String metricName = opts.metric == null ? "r" : opts.metric;
            if ("rmse".startsWith(metricName)) {
                metric = new RMSE();
            } else if ("auc".startsWith(metricName)) {
                metric = new AUC();
            } else {
                throw new IllegalArgumentException("Unknown metric: " + metricName);
            }
        } catch (IllegalArgumentException e) {
            parser.printUsage();
            System.exit(1);
        }
        Random.getInstance().setSeed(opts.seed);
        Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath);
        Instances validSet = InstancesReader.read(opts.attPath, opts.validPath);
        AdditiveGrovesLearner learner = new AdditiveGrovesLearner();
        learner.setBaggingIters(opts.baggingIters);
        learner.setNumTrees(opts.n);
        learner.setMinAlpha(opts.a);
        learner.setMetric(metric);
        learner.setVerbose(true);
        long start = System.currentTimeMillis();
        AdditiveGroves model = learner.buildRegressor(trainSet, validSet);
        System.out.println("Time: " + ((System.currentTimeMillis() - start) / 1000.0d));
        if (opts.outputModelPath != null) {
            PredictorWriter.write(model, opts.outputModelPath);
        }
    }
}
