package mltk.predictor.tree.ensemble.brt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Attribute;
import mltk.core.Instances;
import mltk.core.io.InstancesReader;
import mltk.predictor.Learner;
import mltk.predictor.io.PredictorWriter;
import mltk.predictor.tree.RegressionTree;
import mltk.predictor.tree.RegressionTreeLeaf;
import mltk.predictor.tree.RegressionTreeLearner;
import mltk.util.ArrayUtils;
import mltk.util.MathUtils;
import mltk.util.Permutation;
import mltk.util.Random;
import mltk.util.VectorUtils;

/* loaded from: input_file:mltk/predictor/tree/ensemble/brt/LADBoostLearner.class */
/**
 * Boosted regression tree learner that minimizes least absolute deviation (LAD).
 *
 * <p>The model starts from the median of the targets. Each iteration fits a regression
 * tree to the sign of the current residuals. It then overwrites every leaf's prediction
 * with the median residual of the training rows reaching that leaf, scaled by the
 * learning rate.
 */
public class LADBoostLearner extends Learner {

    /** Whether to print the mean absolute residual after each iteration. */
    private boolean verbose = false;

    /** Number of boosting iterations (trees added after the initial constant tree). */
    private int maxNumIters = 3500;

    /** Maximum number of leaves in each regression tree. */
    private int maxNumLeaves = 100;

    /** Fraction of attributes sampled for each tree; 1.0 uses every attribute. */
    private double alpha = 1.0d;

    /** Factor applied to every leaf prediction of the boosted trees. */
    private double learningRate = 1.0d;

    /** Command-line options for {@link LADBoostLearner#main(String[])}. */
    static class Options {

        @Argument(name = "-r", description = "attribute file path")
        String attPath = null;

        @Argument(name = "-t", description = "train set path", required = true)
        String trainPath = null;

        @Argument(name = "-o", description = "output model path")
        String outputModelPath = null;

        @Argument(name = "-c", description = "max number of leaves (default: 100)")
        int maxNumLeaves = 100;

        @Argument(name = "-m", description = "maximum number of iterations", required = true)
        int maxNumIters = -1;

        @Argument(name = "-s", description = "seed of the random number generator (default: 0)")
        long seed = 0;

        @Argument(name = "-l", description = "learning rate (default: 0.01)")
        double learningRate = 0.01d;
    }

    /** Returns whether per-iteration progress is printed to standard output. */
    public boolean isVerbose() {
        return this.verbose;
    }

    /** Enables or disables per-iteration progress output. */
    public void setVerbose(boolean z) {
        this.verbose = z;
    }

    /** Returns the number of boosting iterations used by {@link #build(Instances)}. */
    public int getMaxNumIters() {
        return this.maxNumIters;
    }

    /** Sets the number of boosting iterations used by {@link #build(Instances)}. */
    public void setMaxNumIters(int i) {
        this.maxNumIters = i;
    }

    /** Returns the maximum number of leaves per tree used by {@link #build(Instances)}. */
    public int getMaxNumLeaves() {
        return this.maxNumLeaves;
    }

    /** Sets the maximum number of leaves per tree used by {@link #build(Instances)}. */
    public void setMaxNumLeaves(int i) {
        this.maxNumLeaves = i;
    }

    /** Returns the factor applied to every boosted leaf prediction. */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** Sets the factor applied to every boosted leaf prediction. */
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /** Returns the fraction of attributes sampled for each tree. */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the fraction of attributes sampled for each tree.
     *
     * @param alpha sampling fraction in (0, 1]; 1.0 uses every attribute
     * @throws IllegalArgumentException if {@code alpha} is outside (0, 1]
     */
    public void setAlpha(double alpha) {
        // Written with a negated conjunction so NaN is rejected as well.
        if (!(alpha > 0.0d && alpha <= 1.0d)) {
            throw new IllegalArgumentException("alpha must be in (0, 1], got " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * Trains an LAD-boosted ensemble.
     *
     * <p>While training, the targets of {@code instances} are temporarily replaced by
     * residual signs. The attribute list is temporarily replaced by the sampled subset.
     * Both are restored before this method returns, including when it throws.
     *
     * @param instances the training set
     * @param maxNumIters number of trees to add after the initial constant tree
     * @param maxNumLeaves maximum number of leaves per tree
     * @return a BRT whose {@code trees[0]} holds the initial median tree followed by one
     *     tree per iteration
     */
    public BRT build(Instances instances, int maxNumIters, int maxNumLeaves) {
        BRT brt = new BRT(1);

        List<Attribute> attributes = instances.getAttributes();
        int numAttributes = attributes.size();
        // Keep at least one attribute when any exist, so a tiny alpha cannot yield an empty subset.
        int numSampled = Math.min(numAttributes, Math.max(1, (int) (numAttributes * this.alpha)));
        int[] selected = new int[numSampled];
        boolean subsample = numSampled < numAttributes;
        Permutation permutation = new Permutation(numAttributes);
        permutation.permute();

        // Back up the original targets; they are overwritten with residual signs below.
        double[] target = new double[instances.size()];
        for (int i = 0; i < target.length; i++) {
            target[i] = instances.get(i).getTarget();
        }

        // Initial model: a single leaf predicting the median target.
        double median = ArrayUtils.getMedian(target);
        brt.trees[0].add(new RegressionTree(new RegressionTreeLeaf(median)));

        RegressionTreeLearner treeLearner = new RegressionTreeLearner();
        treeLearner.setConstructionMode(RegressionTreeLearner.Mode.NUM_LEAVES_LIMITED);
        treeLearner.setMaxNumLeaves(maxNumLeaves);

        double[] residual = new double[target.length];
        for (int i = 0; i < residual.length; i++) {
            residual[i] = target[i] - median;
        }

        try {
            for (int iter = 0; iter < maxNumIters; iter++) {
                // Draw a fresh attribute subset for every tree after the first; previously
                // the permutation was drawn once, so every tree saw the same subset.
                if (subsample && iter > 0) {
                    permutation.permute();
                }
                int[] order = permutation.getPermutation();
                System.arraycopy(order, 0, selected, 0, selected.length);
                Arrays.sort(selected);
                instances.setAttributes(instances.getAttributes(selected));

                // Fit the tree to the sign of the current residuals.
                for (int i = 0; i < residual.length; i++) {
                    instances.get(i).setTarget(MathUtils.sign(residual[i]));
                }
                RegressionTree tree = treeLearner.build(instances);
                brt.trees[0].add(tree);
                instances.setAttributes(attributes);

                fitLeafMedians(tree, instances, residual);

                for (int i = 0; i < residual.length; i++) {
                    residual[i] -= tree.regress(instances.get(i));
                }
                if (this.verbose) {
                    System.out.println("Iteration " + iter + ": " + (VectorUtils.l1norm(residual) / residual.length));
                }
            }
        } finally {
            // Leave the caller's dataset exactly as it was handed in.
            instances.setAttributes(attributes);
            for (int i = 0; i < target.length; i++) {
                instances.get(i).setTarget(target[i]);
            }
        }
        return brt;
    }

    /**
     * Sets each leaf's prediction to the median residual of the rows reaching it.
     *
     * <p>The median is scaled by the learning rate. Only leaves reached by at least one
     * row are updated.
     *
     * @param tree the freshly built tree whose leaves are updated in place
     * @param instances the training set, with its full attribute list in place
     * @param residual current residual per row, indexed like {@code instances}
     */
    private void fitLeafMedians(RegressionTree tree, Instances instances, double[] residual) {
        // Group rows by leaf object identity, so distinct leaves are never merged
        // regardless of how RegressionTreeLeaf defines equals/hashCode.
        Map<RegressionTreeLeaf, List<Integer>> rowsByLeaf = new IdentityHashMap<>();
        for (int i = 0; i < residual.length; i++) {
            RegressionTreeLeaf leaf = tree.getLeafNode(instances.get(i));
            List<Integer> rows = rowsByLeaf.get(leaf);
            if (rows == null) {
                rows = new ArrayList<>();
                rowsByLeaf.put(leaf, rows);
            }
            rows.add(i);
        }
        for (Map.Entry<RegressionTreeLeaf, List<Integer>> entry : rowsByLeaf.entrySet()) {
            List<Integer> rows = entry.getValue();
            double[] values = new double[rows.size()];
            for (int j = 0; j < values.length; j++) {
                values[j] = residual[rows.get(j)];
            }
            entry.getKey().setPrediction(ArrayUtils.getMedian(values) * this.learningRate);
        }
    }

    /** Trains using the configured iteration count and leaf limit. */
    @Override // mltk.predictor.Learner
    public BRT build(Instances instances) {
        return build(instances, this.maxNumIters, this.maxNumLeaves);
    }

    /**
     * Command-line entry point: reads a training set, trains a model, prints the elapsed
     * seconds, and optionally writes the model.
     */
    public static void main(String[] strArr) throws Exception {
        Options options = new Options();
        CmdLineParser parser = new CmdLineParser(LADBoostLearner.class, options);
        try {
            parser.parse(strArr);
        } catch (IllegalArgumentException e) {
            parser.printUsage();
            System.exit(1);
        }
        Random.getInstance().setSeed(options.seed);
        Instances trainSet = InstancesReader.read(options.attPath, options.trainPath);

        LADBoostLearner learner = new LADBoostLearner();
        learner.setLearningRate(options.learningRate);
        learner.setMaxNumIters(options.maxNumIters);
        learner.setMaxNumLeaves(options.maxNumLeaves);
        learner.setVerbose(true);

        long start = System.currentTimeMillis();
        BRT model = learner.build(trainSet);
        System.out.println("Time: " + ((System.currentTimeMillis() - start) / 1000.0d));

        if (options.outputModelPath != null) {
            PredictorWriter.write(model, options.outputModelPath);
        }
    }
}
