package mltk.predictor.tree.ensemble.rf;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import mltk.cmdline.Argument;
import mltk.core.Attribute;
import mltk.core.Instances;
import mltk.predictor.tree.RegressionTree;
import mltk.predictor.tree.RegressionTreeInteriorNode;
import mltk.predictor.tree.RegressionTreeLeaf;
import mltk.predictor.tree.RegressionTreeLearner;
import mltk.predictor.tree.RegressionTreeNode;
import mltk.util.Permutation;
import mltk.util.Random;
import mltk.util.tuple.DoublePair;
import mltk.util.tuple.IntDoublePair;

/* loaded from: input_file:mltk/predictor/tree/ensemble/rf/RandomRegressionTreeLearner.class */
/**
 * Regression tree learner that restricts each split search to a random subset of the attributes,
 * as used when growing trees for a random forest.
 *
 * <p>At every interior node a fresh random subset of attribute positions is drawn. The best split
 * among those attributes is searched for, and ties between equally good splits are broken
 * uniformly at random.
 */
public class RandomRegressionTreeLearner extends RegressionTreeLearner {

    /** User-requested number of attributes per split; a non-positive value selects the default. */
    private int numFeatures = -1;

    /** Number of attributes actually sampled per split, resolved on every call to {@link #build}. */
    private int numFeaturesPerSplit;

    /** Reusable permutation over attribute positions; recreated when the attribute count changes. */
    private Permutation perm;

    /**
     * Command-line options, declared with {@code mltk.cmdline.Argument} annotations.
     * Presumably consumed by {@link RandomRegressionTreeLearner#main}; its body is not available
     * in this file.
     *
     * <p>NOTE(review): the documented CLI default {@code a:0.001} differs from the programmatic
     * constructor default {@code alpha = 0.01}; confirm which one is intended.
     */
    static class Options {

        @Argument(name = "-r", description = "attribute file path", required = true)
        String attPath = null;

        @Argument(name = "-t", description = "train set path", required = true)
        String trainPath = null;

        @Argument(name = "-o", description = "output model path")
        String outputModelPath = null;

        @Argument(name = "-m", description = "construction mode:parameter. Construction mode can be alpha limited (a), depth limited (d), number of leaves limited (l) and minimum leaf size limited (s) (default: a:0.001)")
        String mode = "a:0.001";

        @Argument(name = "-f", description = "number of features to consider")
        int numFeatures = -1;

        @Argument(name = "-b", description = "bagging iterations (default: 100)")
        int baggingIters = 100;
    }

    /** Creates a learner in alpha-limited construction mode with {@code alpha = 0.01}. */
    public RandomRegressionTreeLearner() {
        this.alpha = 0.01d;
        this.mode = RegressionTreeLearner.Mode.ALPHA_LIMITED;
    }

    /**
     * Returns the configured number of attributes to consider at each split.
     *
     * @return the configured value; a non-positive value means {@code max(1, n / 3)} of the
     *     {@code n} attributes are used
     */
    public int getNumFeatures() {
        return this.numFeatures;
    }

    /**
     * Sets the number of attributes to consider at each split.
     *
     * @param numFeatures the number of attributes; values larger than the attribute count are
     *     capped, and a non-positive value selects the default {@code max(1, n / 3)}
     */
    public void setNumFeatures(int numFeatures) {
        this.numFeatures = numFeatures;
    }

    /**
     * Builds a randomized regression tree using the configured construction mode.
     *
     * @param instances the training data
     * @return the learned tree
     * @throws IllegalStateException if the construction mode is not one this learner supports
     */
    @Override
    public RegressionTree build(Instances instances) {
        int numAttributes = instances.getAttributes().size();
        // Resolve per build (instead of overwriting numFeatures) so that the "auto" setting keeps
        // adapting when the learner is reused on data with a different number of attributes.
        this.numFeaturesPerSplit = resolveNumFeatures(this.numFeatures, numAttributes);
        if (this.perm == null || this.perm.size() != numAttributes) {
            this.perm = new Permutation(numAttributes);
        }
        switch (this.mode) {
            case DEPTH_LIMITED:
                return buildDepthLimitedTree(instances, this.maxDepth);
            case NUM_LEAVES_LIMITED:
                return buildNumLeafLimitedTree(instances, this.maxNumLeaves);
            case ALPHA_LIMITED:
                return buildAlphaLimitedTree(instances, this.alpha);
            case MIN_LEAF_SIZE_LIMITED:
                return buildMinLeafSizeLimitedTree(instances, this.minLeafSize);
            default:
                throw new IllegalStateException("Unsupported construction mode: " + this.mode);
        }
    }

    /**
     * Computes how many attributes to sample per split.
     *
     * @param requested the user setting; non-positive selects {@code max(1, numAttributes / 3)}
     * @param numAttributes the number of attributes in the data
     * @return a value in {@code [0, numAttributes]}; at least 1 whenever any attribute exists
     */
    private static int resolveNumFeatures(int requested, int numAttributes) {
        int wanted = requested > 0 ? requested : Math.max(1, numAttributes / 3);
        return Math.min(wanted, numAttributes);
    }

    /**
     * Creates either a leaf or an interior node for the given node data.
     *
     * <p>A leaf (predicting {@code stats[1]}) is returned when {@code getStats} reports true, when
     * {@code stats[0] < limit}, or when no sampled attribute yields a finite split evaluation.
     * Otherwise the split with the lowest evaluation among a random attribute subset is used;
     * ties are broken uniformly at random.
     *
     * @param dataset the instances reaching this node together with their per-attribute sorted
     *     lists
     * @param limit the node becomes a leaf when {@code stats[0]} is below this value
     * @param stats output array filled by {@code getStats}; when it has more than two entries,
     *     {@code stats[2]} is set for interior nodes
     * @return the created node
     */
    @Override
    protected RegressionTreeNode createNode(RegressionTreeLearner.Dataset dataset, int limit, double[] stats) {
        boolean stopSplitting = getStats(dataset.instances, stats);
        // NOTE(review): stats[0]/stats[1] presumably hold total weight and weighted mean (their
        // product is passed on as the target sum) - confirm against RegressionTreeLearner.getStats.
        final double totalWeights = stats[0];
        final double weightedMean = stats[1];
        final double sum = totalWeights * weightedMean;

        if (totalWeights < limit || stopSplitting) {
            return new RegressionTreeLeaf(weightedMean);
        }

        boolean[] selected = sampleAttributes();
        List<Attribute> attributes = dataset.instances.getAttributes();
        double bestEval = Double.POSITIVE_INFINITY;
        // All splits tied for the best evaluation; one of them is picked at random below.
        List<IntDoublePair> bestSplits = new ArrayList<>();
        // Iterate attributes in their natural order so tie ordering (and thus the random pick)
        // does not depend on the sampling order.
        for (int i = 0; i < attributes.size(); i++) {
            if (i >= selected.length || !selected[i]) {
                continue;
            }
            List<IntDoublePair> sortedList = dataset.sortedLists.get(i);
            List<Double> uniqueValues = new ArrayList<>(sortedList.size());
            List<DoublePair> histogram = new ArrayList<>(sortedList.size());
            getHistogram(dataset.instances, sortedList, uniqueValues, totalWeights, sum, histogram);
            if (uniqueValues.size() <= 1) {
                continue; // a single distinct value cannot be split
            }
            DoublePair split = split(uniqueValues, histogram, totalWeights, sum);
            if (split.v2 <= bestEval) {
                if (split.v2 < bestEval) {
                    bestSplits.clear();
                    bestEval = split.v2;
                }
                bestSplits.add(new IntDoublePair(attributes.get(i).getIndex(), split.v1));
            }
        }

        if (bestEval == Double.POSITIVE_INFINITY) {
            return new RegressionTreeLeaf(weightedMean);
        }

        IntDoublePair chosen = bestSplits.get(Random.getInstance().nextInt(bestSplits.size()));
        RegressionTreeInteriorNode node = new RegressionTreeInteriorNode(chosen.v1, chosen.v2);
        if (stats.length > 2) {
            stats[2] = bestEval + totalWeights * weightedMean * weightedMean;
        }
        return node;
    }

    /**
     * Draws a fresh random subset of attribute positions for one split search.
     *
     * @return a mask over attribute positions; {@code true} marks a sampled attribute
     */
    private boolean[] sampleAttributes() {
        int[] order = this.perm.permute().getPermutation();
        boolean[] selected = new boolean[order.length];
        // Clamp so an oversized request can never index past the permutation.
        int count = Math.min(this.numFeaturesPerSplit, order.length);
        for (int k = 0; k < count; k++) {
            selected[order[k]] = true;
        }
        return selected;
    }

    /**
     * Command-line entry point; see {@link Options} for the declared arguments.
     *
     * <p>NOTE(review): this class was recovered by decompilation and the body of this method could
     * not be reconstructed, so it currently always throws {@link UnsupportedOperationException}.
     * TODO restore the original implementation from source control.
     *
     * @param args command-line arguments (currently ignored)
     * @throws Exception as declared by the original signature
     */
    public static void main(String[] args) throws Exception {
        throw new UnsupportedOperationException("Method not decompiled: mltk.predictor.tree.ensemble.rf.RandomRegressionTreeLearner.main(java.lang.String[]):void");
    }
}
