package mltk.predictor.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import mltk.cmdline.Argument;
import mltk.core.Attribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.SparseVector;
import mltk.predictor.Learner;
import mltk.util.Element;
import mltk.util.Random;
import mltk.util.Stack;
import mltk.util.tuple.DoublePair;
import mltk.util.tuple.IntDoublePair;
import mltk.util.tuple.IntDoublePairComparator;

/* loaded from: input_file:mltk/predictor/tree/RegressionTreeLearner.class */
public class RegressionTreeLearner extends Learner {
    /** Depth limit handed to {@code buildDepthLimitedTree} when the mode is {@code DEPTH_LIMITED}. */
    protected int maxDepth;

    /** Leaf-count limit handed to {@code buildNumLeafLimitedTree} when the mode is {@code NUM_LEAVES_LIMITED}. */
    protected int maxNumLeaves;

    /** Minimum node size handed to {@code buildMinLeafSizeLimitedTree} when the mode is {@code MIN_LEAF_SIZE_LIMITED}. */
    protected int minLeafSize;

    /**
     * Fraction of the training set size; {@code buildAlphaLimitedTree} multiplies it by the number of
     * instances and truncates to obtain a minimum leaf size.
     */
    protected double alpha = 0.01d;

    /** Construction mode that {@code build} dispatches on. */
    protected Mode mode = Mode.ALPHA_LIMITED;

    /**
     * Boxed zero used by {@code getHistogram} both as the binary-search key and as the value inserted
     * for instances that have no explicit entry in a sorted attribute list.
     * Created with {@code Double.valueOf} because the {@code Double(double)} constructor is deprecated.
     */
    protected static final Double ZERO = Double.valueOf(0.0d);

    /** Lazily filled cache for the decompiler-synthesized ordinal lookup table of {@link Mode}. */
    private static /* synthetic */ int[] $SWITCH_TABLE$mltk$predictor$tree$RegressionTreeLearner$Mode;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:mltk/predictor/tree/RegressionTreeLearner$Dataset.class */
    public static class Dataset {
        /** Instances held by this dataset; the container is created fresh in the constructor. */
        public Instances instances;

        /**
         * One list per attribute position. Each entry pairs an index into {@link #instances} (v1) with
         * that instance's stored value for the attribute (v2). Dense zeros are never stored.
         * Lists are sorted with {@code IntDoublePairComparator(false)} — presumably by value; confirm
         * against the comparator.
         */
        public List<List<IntDoublePair>> sortedLists;

        /**
         * Creates an empty dataset that shares the attributes and target attribute of {@code instances}.
         * No instances are copied and no attribute lists are added here.
         */
        Dataset(Instances instances) {
            this.instances = new Instances(instances.getAttributes(), instances.getTargetAttribute());
            this.sortedLists = new ArrayList<List<IntDoublePair>>(instances.dimension());
        }

        /**
         * Builds a dataset from {@code instances}: every instance is cloned into the new dataset and
         * the per-attribute value lists are filled and sorted.
         *
         * @param instances the source instances
         * @return the populated dataset
         */
        static Dataset create(Instances instances) {
            Dataset result = new Dataset(instances);

            // Map each attribute's index to its position in the attribute list.
            List<Attribute> attributes = instances.getAttributes();
            Map<Integer, Integer> positionByIndex = new HashMap<Integer, Integer>();
            for (int pos = 0; pos < attributes.size(); pos++) {
                positionByIndex.put(Integer.valueOf(attributes.get(pos).getIndex()), Integer.valueOf(pos));
            }

            for (int k = 0; k < instances.dimension(); k++) {
                result.sortedLists.add(new ArrayList<IntDoublePair>());
            }

            for (int row = 0; row < instances.size(); row++) {
                Instance instance = instances.get(row);
                result.instances.add(instance.m37clone());
                if (instance.isSparse()) {
                    SparseVector vector = (SparseVector) instance.getVector();
                    int[] indices = vector.getIndices();
                    double[] values = vector.getValues();
                    for (int k = 0; k < indices.length; k++) {
                        // Only values whose index belongs to a known attribute are recorded.
                        Integer pos = positionByIndex.get(Integer.valueOf(indices[k]));
                        if (pos != null) {
                            result.sortedLists.get(pos.intValue()).add(new IntDoublePair(row, values[k]));
                        }
                    }
                } else {
                    double[] values = instance.getValues();
                    for (int k = 0; k < values.length; k++) {
                        if (values[k] == 0.0d) {
                            // Dense zeros are left implicit, like absent sparse entries.
                            continue;
                        }
                        Integer pos = positionByIndex.get(Integer.valueOf(k));
                        if (pos != null) {
                            result.sortedLists.get(pos.intValue()).add(new IntDoublePair(row, values[k]));
                        }
                    }
                }
            }

            IntDoublePairComparator comparator = new IntDoublePairComparator(false);
            for (List<IntDoublePair> attributeList : result.sortedLists) {
                Collections.sort(attributeList, comparator);
            }
            return result;
        }

        /**
         * Distributes this dataset over two children according to the split of an interior node.
         * Instances for which {@code goLeft} is true go to {@code dataset}, all others to
         * {@code dataset2}. The per-attribute lists are copied in their existing order with the
         * instance indices renumbered to the child's own numbering.
         *
         * @param regressionTreeInteriorNode the node whose {@code goLeft} decides the side
         * @param dataset receives the left instances and lists
         * @param dataset2 receives the right instances and lists
         */
        void split(RegressionTreeInteriorNode regressionTreeInteriorNode, Dataset dataset, Dataset dataset2) {
            // Position of each parent instance inside the left / right child, or -1 if it is not there.
            int[] leftPosition = new int[this.instances.size()];
            int[] rightPosition = new int[this.instances.size()];
            Arrays.fill(leftPosition, -1);
            Arrays.fill(rightPosition, -1);

            for (int row = 0; row < this.instances.size(); row++) {
                Instance instance = this.instances.get(row);
                if (regressionTreeInteriorNode.goLeft(instance)) {
                    dataset.instances.add(instance);
                    leftPosition[row] = dataset.instances.size() - 1;
                } else {
                    dataset2.instances.add(instance);
                    rightPosition[row] = dataset2.instances.size() - 1;
                }
            }

            for (int attr = 0; attr < this.sortedLists.size(); attr++) {
                dataset.sortedLists.add(new ArrayList<IntDoublePair>(dataset.instances.size()));
                dataset2.sortedLists.add(new ArrayList<IntDoublePair>(dataset2.instances.size()));
                for (IntDoublePair entry : this.sortedLists.get(attr)) {
                    int leftRow = leftPosition[entry.v1];
                    int rightRow = rightPosition[entry.v1];
                    if (leftRow != -1) {
                        dataset.sortedLists.get(attr).add(new IntDoublePair(leftRow, entry.v2));
                    }
                    if (rightRow != -1) {
                        dataset2.sortedLists.get(attr).add(new IntDoublePair(rightRow, entry.v2));
                    }
                }
            }
        }
    }

    /* loaded from: input_file:mltk/predictor/tree/RegressionTreeLearner$Mode.class */
    public enum Mode {
        DEPTH_LIMITED,
        NUM_LEAVES_LIMITED,
        ALPHA_LIMITED,
        MIN_LEAF_SIZE_LIMITED;

        /**
         * Returns a fresh array holding all constants in declaration order.
         * The decompiler renamed this from {@code values} to avoid a clash with the built-in enum method.
         *
         * @return a new array the caller may modify freely
         */
        public static Mode[] valuesCustom() {
            return values().clone();
        }
    }

    /* loaded from: input_file:mltk/predictor/tree/RegressionTreeLearner$Options.class */
    /** Command-line options; each field is bound to the flag named in its {@code @Argument}. */
    static class Options {

        @Argument(name = "-r", description = "attribute file path", required = true)
        String attPath;

        @Argument(name = "-t", description = "train set path", required = true)
        String trainPath;

        @Argument(name = "-o", description = "output model path")
        String outputModelPath;

        @Argument(name = "-m", description = "construction mode:parameter. Construction mode can be alpha limited (a), depth limited (d), number of leaves limited (l) and minimum leaf size limited (s) (default: a:0.001)")
        String mode = "a:0.001";

        @Argument(name = "-s", description = "seed of the random number generator (default: 0)")
        long seed;

        Options() {
        }
    }

    /* JADX WARN: Can't fix incorrect switch cases order, some code will duplicate */
    /* JADX WARN: Failed to find 'out' block for switch in B:9:0x0044. Please report as an issue. */
    /* JADX WARN: Removed duplicated region for block: B:16:0x016f  */
    /* JADX WARN: Removed duplicated region for block: B:18:? A[RETURN, SYNTHETIC] */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    /**
     * Command-line entry point.
     *
     * <p>NOTE(review): the decompiler could not recover this method's body, so the stub below
     * unconditionally throws. The real behaviour must be restored from the original sources or
     * the bytecode dump; nothing about it can be inferred from this file beyond the
     * {@code Options} fields it presumably consumes.
     *
     * @param r7 the command-line arguments (unused by the stub)
     * @throws UnsupportedOperationException always, in this decompiled form
     */
    public static void main(java.lang.String[] r7) throws java.lang.Exception {
        /*
            Method dump skipped, instructions count: 377
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: mltk.predictor.tree.RegressionTreeLearner.main(java.lang.String[]):void");
    }

    /**
     * Builds a regression tree using the builder that matches the current construction mode,
     * passing that mode's configured parameter.
     *
     * <p>Switches on the enum directly instead of going through the synthetic ordinal lookup
     * table, so the cases are named rather than numbered.
     *
     * @param instances the training instances
     * @return the built tree, or {@code null} if the mode matches none of the known constants
     * @throws NullPointerException if the construction mode is {@code null}
     */
    @Override // mltk.predictor.Learner
    public RegressionTree build(Instances instances) {
        switch (this.mode) {
            case DEPTH_LIMITED:
                return buildDepthLimitedTree(instances, this.maxDepth);
            case NUM_LEAVES_LIMITED:
                return buildNumLeafLimitedTree(instances, this.maxNumLeaves);
            case ALPHA_LIMITED:
                return buildAlphaLimitedTree(instances, this.alpha);
            case MIN_LEAF_SIZE_LIMITED:
                return buildMinLeafSizeLimitedTree(instances, this.minLeafSize);
            default:
                return null;
        }
    }

    /**
     * Returns the alpha parameter.
     *
     * @return the current value of {@code alpha}
     */
    public double getAlpha() {
        return alpha;
    }

    /**
     * Returns the construction mode.
     *
     * @return the current value of {@code mode}
     */
    public Mode getConstructionMode() {
        return mode;
    }

    /**
     * Returns the maximum depth.
     *
     * @return the current value of {@code maxDepth}
     */
    public int getMaxDepth() {
        return maxDepth;
    }

    /**
     * Returns the maximum number of leaves.
     *
     * @return the current value of {@code maxNumLeaves}
     */
    public int getMaxNumLeaves() {
        return maxNumLeaves;
    }

    /**
     * Returns the minimum leaf size.
     *
     * @return the current value of {@code minLeafSize}
     */
    public int getMinLeafSize() {
        return minLeafSize;
    }

    /**
     * Sets the alpha parameter. The value is stored as given, without validation.
     *
     * @param d the new alpha
     */
    public void setAlpha(double d) {
        alpha = d;
    }

    /**
     * Sets the construction mode that {@code build} dispatches on. The value is stored as given.
     *
     * @param mode the new construction mode
     */
    public void setConstructionMode(Mode mode) {
        this.mode = mode;
    }

    /**
     * Sets the maximum depth. The value is stored as given, without validation.
     *
     * @param i the new maximum depth
     */
    public void setMaxDepth(int i) {
        maxDepth = i;
    }

    /**
     * Sets the maximum number of leaves. The value is stored as given, without validation.
     *
     * @param i the new maximum number of leaves
     */
    public void setMaxNumLeaves(int i) {
        maxNumLeaves = i;
    }

    /**
     * Sets the minimum leaf size. The value is stored as given, without validation.
     *
     * @param i the new minimum leaf size
     */
    public void setMinLeafSize(int i) {
        minLeafSize = i;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds a tree whose minimum node size is a fraction of the training set size.
     * The fraction is multiplied by the number of instances, truncated to an {@code int}, and the
     * work is delegated to {@link #buildMinLeafSizeLimitedTree}.
     *
     * @param instances the training instances
     * @param d the fraction of the training set size
     * @return the built tree
     */
    public RegressionTree buildAlphaLimitedTree(Instances instances, double d) {
        int leafSizeLimit = (int) (d * instances.size());
        return buildMinLeafSizeLimitedTree(instances, leafSizeLimit);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds a tree limited in depth, counting the root as depth 1.
     *
     * <p>A limit of exactly 1 yields a single leaf holding the weighted mean target. Otherwise
     * nodes are taken from a priority queue keyed by the split score reported by
     * {@link #createNode}; children that would sit at the depth limit become leaves directly.
     * Nodes are created with a fixed minimum size of 5. The depth check is an equality test, so a
     * limit below 1 is never reached and only {@code createNode}'s own stopping rules apply.
     *
     * @param instances the training instances
     * @param i the maximum depth
     * @return the built tree
     */
    public RegressionTree buildDepthLimitedTree(Instances instances, int i) {
        RegressionTree tree = new RegressionTree();
        // stats[0]: total weight, stats[1]: weighted mean target, stats[2]: split score from createNode
        double[] stats = new double[3];
        if (i == 1) {
            getStats(instances, stats);
            tree.root = new RegressionTreeLeaf(stats[1]);
            return tree;
        }

        Map<RegressionTreeNode, Dataset> datasetOf = new HashMap<RegressionTreeNode, Dataset>();
        Map<RegressionTreeNode, Integer> depthOf = new HashMap<RegressionTreeNode, Integer>();
        Dataset rootData = Dataset.create(instances);
        RegressionTreeNode root = createNode(rootData, 5, stats);
        tree.root = root;
        PriorityQueue<Element> pending = new PriorityQueue<Element>();
        pending.add(new Element(root, stats[2]));
        datasetOf.put(root, rootData);
        depthOf.put(root, Integer.valueOf(1));

        while (!pending.isEmpty()) {
            RegressionTreeNode node = (RegressionTreeNode) pending.remove().element;
            Dataset data = datasetOf.get(node);
            int depth = depthOf.get(node).intValue();
            if (node.isLeaf()) {
                continue;
            }

            RegressionTreeInteriorNode interior = (RegressionTreeInteriorNode) node;
            Dataset leftData = new Dataset(data.instances);
            Dataset rightData = new Dataset(data.instances);
            data.split(interior, leftData, rightData);

            int childDepth = depth + 1;
            if (childDepth == i) {
                // Children are at the limit: close them off with their weighted means.
                getStats(leftData.instances, stats);
                interior.left = new RegressionTreeLeaf(stats[1]);
                getStats(rightData.instances, stats);
                interior.right = new RegressionTreeLeaf(stats[1]);
                continue;
            }

            RegressionTreeNode left = createNode(leftData, 5, stats);
            interior.left = left;
            if (!left.isLeaf()) {
                pending.add(new Element(left, stats[2]));
                datasetOf.put(left, leftData);
                depthOf.put(left, Integer.valueOf(childDepth));
            }

            RegressionTreeNode right = createNode(rightData, 5, stats);
            interior.right = right;
            if (!right.isLeaf()) {
                pending.add(new Element(right, stats[2]));
                datasetOf.put(right, rightData);
                depthOf.put(right, Integer.valueOf(childDepth));
            }
        }
        return tree;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds a tree by splitting nodes until {@link #createNode} returns a leaf for each branch.
     * The size limit is passed to {@code createNode} unchanged for every node. Nodes and their
     * datasets are tracked on two parallel stacks that are always pushed and popped together.
     *
     * @param instances the training instances
     * @param i the minimum number of instances a node needs in order to be split
     * @return the built tree
     */
    public RegressionTree buildMinLeafSizeLimitedTree(Instances instances, int i) {
        RegressionTree tree = new RegressionTree();
        double[] stats = new double[3];
        Dataset rootData = Dataset.create(instances);
        Stack nodes = new Stack();
        Stack datasets = new Stack();

        tree.root = createNode(rootData, i, stats);
        nodes.push(tree.root);
        datasets.push(rootData);

        while (!nodes.isEmpty()) {
            RegressionTreeNode node = (RegressionTreeNode) nodes.pop();
            Dataset data = (Dataset) datasets.pop();
            if (node.isLeaf()) {
                continue;
            }

            RegressionTreeInteriorNode interior = (RegressionTreeInteriorNode) node;
            Dataset leftData = new Dataset(data.instances);
            Dataset rightData = new Dataset(data.instances);
            data.split(interior, leftData, rightData);

            interior.left = createNode(leftData, i, stats);
            interior.right = createNode(rightData, i, stats);

            nodes.push(interior.left);
            datasets.push(leftData);
            nodes.push(interior.right);
            datasets.push(rightData);
        }
        return tree;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds a tree limited in its number of leaves.
     *
     * <p>Interior nodes are expanded from a priority queue keyed by the split score reported by
     * {@link #createNode}. Expansion stops once the finished leaves plus the queued nodes reach
     * the limit. Every node still queued at that point is replaced in its parent by a leaf
     * holding the node's weighted mean target. Nodes are created with a fixed minimum size of 5.
     * The limit is only checked after a split, so a splittable root is always split at least once.
     *
     * @param instances the training instances
     * @param i the maximum number of leaves
     * @return the built tree
     */
    public RegressionTree buildNumLeafLimitedTree(Instances instances, int i) {
        RegressionTree tree = new RegressionTree();
        // stats[0]: total weight, stats[1]: weighted mean target, stats[2]: split score from createNode
        double[] stats = new double[3];
        Map<RegressionTreeNode, Double> meanOf = new HashMap<RegressionTreeNode, Double>();
        Map<RegressionTreeNode, Dataset> datasetOf = new HashMap<RegressionTreeNode, Dataset>();
        Dataset rootData = Dataset.create(instances);
        PriorityQueue<Element> frontier = new PriorityQueue<Element>();

        RegressionTreeNode root = createNode(rootData, 5, stats);
        tree.root = root;
        frontier.add(new Element(root, stats[2]));
        datasetOf.put(root, rootData);
        meanOf.put(root, Double.valueOf(stats[1]));

        int leafCount = 0;
        while (!frontier.isEmpty()) {
            RegressionTreeNode node = (RegressionTreeNode) frontier.remove().element;
            Dataset data = datasetOf.get(node);
            if (node.isLeaf()) {
                continue;
            }

            RegressionTreeInteriorNode interior = (RegressionTreeInteriorNode) node;
            Dataset leftData = new Dataset(data.instances);
            Dataset rightData = new Dataset(data.instances);
            data.split(interior, leftData, rightData);

            RegressionTreeNode left = createNode(leftData, 5, stats);
            interior.left = left;
            if (left.isLeaf()) {
                leafCount++;
            } else {
                meanOf.put(left, Double.valueOf(stats[1]));
                frontier.add(new Element(left, stats[2]));
                datasetOf.put(left, leftData);
            }

            RegressionTreeNode right = createNode(rightData, 5, stats);
            interior.right = right;
            if (right.isLeaf()) {
                leafCount++;
            } else {
                meanOf.put(right, Double.valueOf(stats[1]));
                frontier.add(new Element(right, stats[2]));
                datasetOf.put(right, rightData);
            }

            // Every queued node will end up as a leaf if expansion stops now.
            if (leafCount + frontier.size() >= i) {
                break;
            }
        }

        // Collapse whatever is still queued into leaves, using a child -> parent lookup.
        Map<RegressionTreeNode, RegressionTreeNode> parentOf = new HashMap<RegressionTreeNode, RegressionTreeNode>();
        traverse(tree.root, parentOf);
        while (!frontier.isEmpty()) {
            RegressionTreeNode node = (RegressionTreeNode) frontier.remove().element;
            double mean = meanOf.get(node).doubleValue();
            RegressionTreeInteriorNode parent = (RegressionTreeInteriorNode) parentOf.get(node);
            if (parent.left == node) {
                parent.left = new RegressionTreeLeaf(mean);
            } else {
                parent.right = new RegressionTreeLeaf(mean);
            }
        }
        return tree;
    }

    /**
     * Creates the node for a dataset: a leaf holding the weighted mean target, or an interior
     * node carrying the best split found.
     *
     * <p>A leaf is returned when the dataset has fewer than {@code i} instances, when
     * {@link #getStats} reports that all targets are equal (or the dataset is empty), or when no
     * attribute yields a split with a score below positive infinity. Ties between equally good
     * splits are broken at random.
     *
     * @param dataset the data reaching this node
     * @param i the minimum number of instances required to attempt a split
     * @param dArr output array; slots 0 and 1 are filled by {@code getStats}, and slot 2 (if
     *             present) receives the best score plus {@code totalWeight * mean * mean} when an
     *             interior node is returned
     * @return the created node
     */
    protected RegressionTreeNode createNode(Dataset dataset, int i, double[] dArr) {
        boolean stop = getStats(dataset.instances, dArr);
        double totalWeight = dArr[0];
        double mean = dArr[1];
        if (stop || dataset.instances.size() < i) {
            return new RegressionTreeLeaf(mean);
        }

        double weightedSum = totalWeight * mean;
        double bestScore = Double.POSITIVE_INFINITY;
        List<IntDoublePair> bestSplits = new ArrayList<IntDoublePair>();
        List<Attribute> attributes = dataset.instances.getAttributes();
        for (int pos = 0; pos < attributes.size(); pos++) {
            int attributeIndex = attributes.get(pos).getIndex();
            List<IntDoublePair> sorted = dataset.sortedLists.get(pos);
            List<Double> uniqueValues = new ArrayList<Double>(sorted.size());
            List<DoublePair> histogram = new ArrayList<DoublePair>(sorted.size());
            getHistogram(dataset.instances, sorted, uniqueValues, totalWeight, weightedSum, histogram);
            if (uniqueValues.size() <= 1) {
                // A single distinct value cannot be split.
                continue;
            }

            DoublePair candidate = split(uniqueValues, histogram, totalWeight, weightedSum);
            // Written as a positive test so that a NaN score is rejected.
            if (candidate.v2 <= bestScore) {
                IntDoublePair choice = new IntDoublePair(attributeIndex, candidate.v1);
                if (candidate.v2 < bestScore) {
                    bestSplits.clear();
                    bestScore = candidate.v2;
                }
                bestSplits.add(choice);
            }
        }

        if (!(bestScore < Double.POSITIVE_INFINITY)) {
            return new RegressionTreeLeaf(mean);
        }

        int pick = Random.getInstance().nextInt(bestSplits.size());
        IntDoublePair chosen = bestSplits.get(pick);
        RegressionTreeInteriorNode node = new RegressionTreeInteriorNode(chosen.v1, chosen.v2);
        if (dArr.length > 2) {
            dArr[2] = bestScore + (totalWeight * mean * mean);
        }
        return node;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Collapses a sorted attribute list into its distinct values and, for each value, the total
     * weight and weighted target sum of the instances carrying it.
     *
     * <p>If the sorted list has fewer entries than there are instances, the remaining instances
     * have no stored value and are treated as having value zero. Their mass is the difference
     * between the supplied totals and what the list accounts for. It is inserted as a new zero bin
     * at its sorted position, or, when a bin for 0.0 already exists, added to that bin. The
     * original dropped the mass in that second case, leaving the histogram inconsistent with the
     * totals.
     *
     * <p>Nothing is written when the sorted list is empty.
     *
     * @param instances the instances indexed by the pairs' {@code v1}
     * @param list the (instance index, value) pairs for one attribute, ordered by value
     * @param list2 output: the distinct values, in order
     * @param d the total weight of all instances
     * @param d2 the weighted target sum of all instances
     * @param list3 output: per distinct value, a pair of (weight, weighted target sum)
     */
    public void getHistogram(Instances instances, List<IntDoublePair> list, List<Double> list2, double d, double d2, List<DoublePair> list3) {
        if (list.isEmpty()) {
            return;
        }

        IntDoublePair first = list.get(0);
        Instance firstInstance = instances.get(first.v1);
        double currentValue = first.v2;
        double binWeight = firstInstance.getWeight();
        double binSum = firstInstance.getTarget() * binWeight;

        for (int i = 1; i < list.size(); i++) {
            IntDoublePair entry = list.get(i);
            Instance instance = instances.get(entry.v1);
            double weight = instance.getWeight();
            double target = instance.getTarget();
            if (entry.v2 != currentValue) {
                // Value changed: close the running bin and start a new one.
                list2.add(Double.valueOf(currentValue));
                list3.add(new DoublePair(binWeight, binSum));
                currentValue = entry.v2;
                binWeight = weight;
                binSum = target * weight;
            } else {
                binWeight += weight;
                binSum = binSum + (target * weight);
            }
        }
        list2.add(Double.valueOf(currentValue));
        list3.add(new DoublePair(binWeight, binSum));

        if (list.size() != instances.size()) {
            double coveredWeight = 0.0d;
            double coveredSum = 0.0d;
            for (DoublePair bin : list3) {
                coveredWeight += bin.v1;
                coveredSum += bin.v2;
            }
            double zeroWeight = d - coveredWeight;
            double zeroSum = d2 - coveredSum;

            int found = Collections.binarySearch(list2, ZERO);
            if (found < 0) {
                int insertAt = (-found) - 1;
                list2.add(insertAt, ZERO);
                list3.add(insertAt, new DoublePair(zeroWeight, zeroSum));
            } else {
                // A zero bin already exists; fold the unlisted mass into it.
                DoublePair zeroBin = list3.get(found);
                list3.set(found, new DoublePair(zeroBin.v1 + zeroWeight, zeroBin.v2 + zeroSum));
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Computes the total weight and the weighted mean target of a set of instances.
     *
     * <p>Slot 0 of {@code dArr} receives the sum of the weights and slot 1 the weighted target sum
     * divided by that total. For an empty set both slots are zero.
     *
     * @param instances the instances to summarise
     * @param dArr output array with at least two slots
     * @return {@code true} if the set is empty or every target equals the first instance's target
     */
    public boolean getStats(Instances instances, double[] dArr) {
        dArr[1] = 0.0d;
        dArr[0] = 0.0d;
        if (instances.size() == 0) {
            return true;
        }

        final double firstTarget = instances.get(0).getTarget();
        boolean allEqual = true;
        double weightSum = 0.0d;
        double weightedTargetSum = 0.0d;
        for (Iterator<Instance> it = instances.iterator(); it.hasNext();) {
            Instance instance = it.next();
            double weight = instance.getWeight();
            double target = instance.getTarget();
            weightSum = weightSum + weight;
            weightedTargetSum = weightedTargetSum + (weight * target);
            allEqual = allEqual && target == firstTarget;
        }
        dArr[0] = weightSum;
        dArr[1] = weightedTargetSum / weightSum;
        return allEqual;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Finds the best threshold between consecutive distinct values of one attribute.
     *
     * <p>For each boundary the score is {@code -(leftSum^2 / leftWeight) - (rightSum^2 / rightWeight)},
     * where lower is better. The candidate threshold is the midpoint of the two neighbouring
     * values. Ties are broken at random. Callers must supply at least two distinct values.
     *
     * @param list the distinct values, in order
     * @param list2 per value, a pair of (weight, weighted target sum)
     * @param d the total weight
     * @param d2 the total weighted target sum
     * @return a pair of (chosen threshold, best score)
     */
    public DoublePair split(List<Double> list, List<DoublePair> list2, double d, double d2) {
        DoublePair firstBin = list2.get(0);
        double leftWeight = firstBin.v1;
        double rightWeight = d - leftWeight;
        double leftSum = firstBin.v2;
        double rightSum = d2 - leftSum;

        double bestScore = (((-leftSum) * leftSum) / leftWeight) - ((rightSum * rightSum) / rightWeight);
        List<Double> thresholds = new ArrayList<Double>();
        thresholds.add(Double.valueOf((list.get(0).doubleValue() + list.get(1).doubleValue()) / 2.0d));

        for (int k = 1; k < list.size() - 1; k++) {
            // Move bin k from the right side to the left side.
            DoublePair bin = list2.get(k);
            leftWeight += bin.v1;
            rightWeight -= bin.v1;
            leftSum += bin.v2;
            rightSum -= bin.v2;

            double score = (((-leftSum) * leftSum) / leftWeight) - ((rightSum * rightSum) / rightWeight);
            if (!(score <= bestScore)) {
                continue;
            }
            double midpoint = (list.get(k).doubleValue() + list.get(k + 1).doubleValue()) / 2.0d;
            if (score < bestScore) {
                bestScore = score;
                thresholds.clear();
            }
            thresholds.add(Double.valueOf(midpoint));
        }

        int pick = Random.getInstance().nextInt(thresholds.size());
        return new DoublePair(thresholds.get(pick).doubleValue(), bestScore);
    }

    /**
     * Records the parent of every node below {@code regressionTreeNode}, visiting the left
     * subtree before the right one. Leaves and {@code null} children contribute nothing.
     *
     * @param regressionTreeNode the subtree root to walk
     * @param map receives one child-to-parent entry per non-null child
     */
    protected void traverse(RegressionTreeNode regressionTreeNode, Map<RegressionTreeNode, RegressionTreeNode> map) {
        if (regressionTreeNode.isLeaf()) {
            return;
        }
        RegressionTreeInteriorNode interior = (RegressionTreeInteriorNode) regressionTreeNode;

        RegressionTreeNode leftChild = interior.left;
        if (leftChild != null) {
            map.put(leftChild, regressionTreeNode);
            traverse(leftChild, map);
        }

        RegressionTreeNode rightChild = interior.right;
        if (rightChild != null) {
            map.put(rightChild, regressionTreeNode);
            traverse(rightChild, map);
        }
    }

    /**
     * Returns the decompiler-synthesized table that maps each {@link Mode} ordinal to a switch
     * case number, building and caching it on first use.
     * A constant that cannot be resolved ({@code NoSuchFieldError}) keeps the default entry 0.
     *
     * @return the cached table, indexed by {@code Mode.ordinal()}
     */
    static /* synthetic */ int[] $SWITCH_TABLE$mltk$predictor$tree$RegressionTreeLearner$Mode() {
        int[] cached = $SWITCH_TABLE$mltk$predictor$tree$RegressionTreeLearner$Mode;
        if (cached == null) {
            int[] table = new int[Mode.valuesCustom().length];
            try {
                table[Mode.ALPHA_LIMITED.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
                // Constant missing at run time: leave its entry at 0.
            }
            try {
                table[Mode.DEPTH_LIMITED.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Constant missing at run time: leave its entry at 0.
            }
            try {
                table[Mode.MIN_LEAF_SIZE_LIMITED.ordinal()] = 4;
            } catch (NoSuchFieldError ignored) {
                // Constant missing at run time: leave its entry at 0.
            }
            try {
                table[Mode.NUM_LEAVES_LIMITED.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // Constant missing at run time: leave its entry at 0.
            }
            $SWITCH_TABLE$mltk$predictor$tree$RegressionTreeLearner$Mode = table;
            cached = table;
        }
        return cached;
    }
}
