package mltk.predictor.function;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.predictor.Learner;
import mltk.util.Element;
import mltk.util.Random;
import mltk.util.tuple.DoublePair;

/* loaded from: input_file:mltk/predictor/function/LineCutter.class */
public class LineCutter extends Learner {

    /** Minimum interval weight used by the leaf-count-limited builders (see {@link #split(List, List, Interval)}). */
    private static final double DEFAULT_MIN_WEIGHT = 5.0d;

    /** Index of the attribute to cut; -1 until {@link #setAttributeIndex(int)} is called. */
    private int attIndex;
    /** Maximum number of intervals used when {@code leafLimited} is true. */
    private int numIntervals;
    /** Whether to replace segment predictions with {@link #lineSearch(Instances, Function1D)} results. */
    private boolean lineSearch;
    /** If true, growth is limited by {@code numIntervals}; otherwise by {@code alpha}. */
    private boolean leafLimited;
    /** Fraction of the instance count used as the minimum splittable interval weight when not leaf limited. */
    private double alpha;

    /**
     * A node of the binary cut tree: the half-open range {@code [start, end)} of distinct attribute
     * values, with the total weight and weighted target sum over that range.
     */
    public static class Interval implements Comparable<Interval> {
        /** True once this node has been chosen to be split in the final function. */
        boolean finalized;
        int start;
        int end;
        /** Cut point; +Infinity before {@code split} is attempted, NaN if the node cannot be split. */
        double split = Double.POSITIVE_INFINITY;
        double weight;
        double sum;
        /** Priority key: negative of the gain obtained by splitting; smaller is better. */
        double value;
        double gain;
        Interval left;
        Interval right;

        Interval() {
        }

        Interval(int start, int end, double weight, double sum) {
            this.start = start;
            this.end = end;
            this.weight = weight;
            this.sum = sum;
        }

        @Override
        public int compareTo(Interval other) {
            // Kept as plain < / > (not Double.compare) so -0.0 and 0.0 tie, as before.
            if (this.value < other.value) {
                return -1;
            }
            return this.value > other.value ? 1 : 0;
        }

        /** Returns the weighted mean target of this interval ({@code sum / weight}). */
        double getPrediction() {
            return this.sum / this.weight;
        }

        boolean isFinalized() {
            return this.finalized;
        }

        boolean isInteriorNode() {
            return this.split < Double.POSITIVE_INFINITY;
        }

        /** True when {@code split} decided this interval cannot be cut further. */
        boolean isLeaf() {
            return Double.isNaN(this.split);
        }
    }

    /**
     * Builds a piecewise-constant function, splitting every interval whose weight exceeds
     * {@code minWeight} (best gain first) until no interval can be split.
     *
     * @param func         output; its {@code splits} and {@code predictions} are assigned
     * @param uniqueValues sorted distinct attribute values
     * @param stats        per-value (weight, weighted target sum), parallel to {@code uniqueValues}
     * @param minWeight    intervals with weight at or below this are not split
     */
    protected static void build(Function1D func, List<Double> uniqueValues, List<DoublePair> stats, double minWeight) {
        if (uniqueValues.size() == 1) {
            DoublePair only = stats.get(0);
            setConstant(func, only.v2 / only.v1);
            return;
        }
        Interval root = newRoot(stats);
        split(uniqueValues, stats, root, minWeight);
        grow(uniqueValues, stats, root, minWeight, Integer.MAX_VALUE);
        writeFunction(func, root);
    }

    /**
     * Builds a piecewise-constant function with at most {@code numIntervals} segments, greedily
     * splitting the interval with the best gain first.
     *
     * <p>When {@code numIntervals < 2}, or when the whole range cannot be split, the result is a
     * single constant segment (previously this left the arrays null or threw an NPE).
     *
     * @param func         output; its {@code splits} and {@code predictions} are assigned
     * @param uniqueValues sorted distinct attribute values
     * @param stats        per-value (weight, weighted target sum), parallel to {@code uniqueValues}
     * @param numIntervals maximum number of segments
     */
    protected static void build(Function1D func, List<Double> uniqueValues, List<DoublePair> stats, int numIntervals) {
        if (uniqueValues.size() == 1) {
            DoublePair only = stats.get(0);
            setConstant(func, only.v2 / only.v1);
            return;
        }
        Interval root = newRoot(stats);
        if (numIntervals < 2) {
            setConstant(func, root.getPrediction());
            return;
        }
        split(uniqueValues, stats, root);
        if (root.isLeaf()) {
            // Root could not be cut (too little weight or a single value): one segment.
            setConstant(func, root.getPrediction());
            return;
        }
        if (numIntervals == 2) {
            // Special-cased so the children are not split (keeps Random draws unchanged).
            func.splits = new double[]{root.split, Double.POSITIVE_INFINITY};
            func.predictions = new double[]{root.left.getPrediction(), root.right.getPrediction()};
            return;
        }
        grow(uniqueValues, stats, root, DEFAULT_MIN_WEIGHT, numIntervals - 1);
        writeFunction(func, root);
    }

    /**
     * Collapses runs of equal attribute values into per-value statistics.
     *
     * @param pairs        elements sorted by {@code weight} (the attribute value); each element holds
     *                     (target, instance weight)
     * @param uniqueValues receives each distinct attribute value
     * @param stats        receives (total weight, sum of target * weight) per distinct value
     */
    protected static void getStats(List<Element<DoublePair>> pairs, List<Double> uniqueValues, List<DoublePair> stats) {
        if (pairs.isEmpty()) {
            return;
        }
        Element<DoublePair> first = pairs.get(0);
        double currentValue = first.weight;
        double groupWeight = first.element.v2;
        double groupSum = first.element.v1 * groupWeight;
        for (int k = 1; k < pairs.size(); k++) {
            Element<DoublePair> e = pairs.get(k);
            double w = e.element.v2;
            double target = e.element.v1;
            if (e.weight != currentValue) {
                uniqueValues.add(currentValue);
                stats.add(new DoublePair(groupWeight, groupSum));
                currentValue = e.weight;
                groupWeight = w;
                groupSum = target * w;
            } else {
                groupWeight += w;
                groupSum += w * target;
            }
        }
        uniqueValues.add(currentValue);
        stats.add(new DoublePair(groupWeight, groupSum));
    }

    /**
     * Walks the finalized tree in order, appending cut points to {@code splits} and leaf
     * predictions to {@code predictions}. A non-finalized node is treated as a leaf.
     */
    protected static void inorder(Interval interval, List<Double> splits, List<Double> predictions) {
        if (!interval.isFinalized()) {
            predictions.add(interval.getPrediction());
            return;
        }
        inorder(interval.left, splits, predictions);
        splits.add(interval.split);
        inorder(interval.right, splits, predictions);
    }

    /**
     * Replaces each segment's prediction with sum(target * weight) / sum(|target| * (1 - |target|) * weight)
     * over the instances in that segment, or 0 when the denominator is 0.
     * NOTE(review): this looks like a Newton step for logistic-style residuals — confirm. The update is
     * applied in place to the array returned by {@code getPredictions()}, which presumably is the
     * function's backing array — verify in Function1D.
     */
    public static void lineSearch(Instances instances, Function1D function1D) {
        double[] predictions = function1D.getPredictions();
        double[] numerator = new double[predictions.length];
        double[] denominator = new double[numerator.length];
        for (Iterator<Instance> it = instances.iterator(); it.hasNext();) {
            Instance instance = it.next();
            int segment = function1D.getSegmentIndex(instance);
            numerator[segment] = numerator[segment] + (instance.getTarget() * instance.getWeight());
            double absTarget = Math.abs(instance.getTarget());
            denominator[segment] = denominator[segment] + (absTarget * (1.0d - absTarget) * instance.getWeight());
        }
        for (int k = 0; k < predictions.length; k++) {
            predictions[k] = denominator[k] == 0.0d ? 0.0d : numerator[k] / denominator[k];
        }
    }

    /** Equivalent to {@code split(values, stats, interval, 5.0)}. */
    protected static void split(List<Double> values, List<DoublePair> stats, Interval interval) {
        split(values, stats, interval, DEFAULT_MIN_WEIGHT);
    }

    /**
     * Finds the cut of {@code interval} minimising weighted squared error (maximising
     * sumL^2/wL + sumR^2/wR). Ties are broken uniformly at random via {@code mltk.util.Random}.
     *
     * <p>If the interval's weight is at or below {@code minWeight}, or it spans a single value, its
     * {@code split} is set to NaN (leaf). Otherwise {@code left}/{@code right} children are created and
     * {@code split}, {@code gain} and {@code value} are set. The cut is placed midway between adjacent
     * distinct values.
     */
    protected static void split(List<Double> values, List<DoublePair> stats, Interval interval, double minWeight) {
        if (interval.weight <= minWeight || interval.end - interval.start <= 1) {
            interval.split = Double.NaN;
            return;
        }
        interval.left = new Interval();
        interval.right = new Interval();
        int start = interval.start;
        int end = interval.end;
        interval.left.start = start;
        interval.right.end = end;

        double totalWeight = interval.weight;
        double totalSum = interval.sum;
        double leftWeight = stats.get(start).v1;
        double rightWeight = totalWeight - leftWeight;
        double leftSum = stats.get(start).v2;
        double rightSum = totalSum - leftSum;
        double bestLoss = ((-leftSum * leftSum) / leftWeight) - ((rightSum * rightSum) / rightWeight);
        // Each candidate: {cut, last left index, leftWeight, leftSum, rightWeight, rightSum}.
        List<double[]> candidates = new ArrayList<double[]>();
        candidates.add(new double[]{(values.get(start) + values.get(start + 1)) / 2.0d, start,
                leftWeight, leftSum, rightWeight, rightSum});
        for (int k = start + 1; k < end - 1; k++) {
            DoublePair bin = stats.get(k);
            leftWeight += bin.v1;
            rightWeight -= bin.v1;
            leftSum += bin.v2;
            rightSum -= bin.v2;
            double loss = ((-leftSum * leftSum) / leftWeight) - ((rightSum * rightSum) / rightWeight);
            if (loss <= bestLoss) {
                double cut = (values.get(k) + values.get(k + 1)) / 2.0d;
                if (loss < bestLoss) {
                    bestLoss = loss;
                    candidates.clear();
                }
                candidates.add(new double[]{cut, k, leftWeight, leftSum, rightWeight, rightSum});
            }
        }
        double[] chosen = candidates.get(Random.getInstance().nextInt(candidates.size()));
        interval.split = chosen[0];
        interval.left.end = ((int) chosen[1]) + 1;
        interval.right.start = ((int) chosen[1]) + 1;
        interval.left.weight = chosen[2];
        interval.left.sum = chosen[3];
        interval.right.weight = chosen[4];
        interval.right.sum = chosen[5];
        interval.gain = ((chosen[3] / chosen[2]) * chosen[3]) + ((chosen[5] / chosen[4]) * chosen[5]);
        interval.value = (-interval.gain) + ((totalSum / totalWeight) * totalSum);
    }

    /** Returns (sum of v1, sum of v2) over {@code list[from, to)}. */
    protected static DoublePair sumUp(List<DoublePair> list, int from, int to) {
        double sumV1 = 0.0d;
        double sumV2 = 0.0d;
        for (int k = from; k < to; k++) {
            DoublePair pair = list.get(k);
            sumV1 += pair.v1;
            sumV2 += pair.v2;
        }
        return new DoublePair(sumV1, sumV2);
    }

    /** Creates the root interval covering all of {@code stats}. */
    private static Interval newRoot(List<DoublePair> stats) {
        DoublePair total = sumUp(stats, 0, stats.size());
        return new Interval(0, stats.size(), total.v1, total.v2);
    }

    /**
     * Best-first growth: repeatedly finalizes the queued interval with the smallest {@code value},
     * splits its children and queues the splittable ones, stopping after {@code maxSplits} splits.
     * {@code root} must already have been passed to {@code split}.
     */
    private static void grow(List<Double> values, List<DoublePair> stats, Interval root, double minWeight, int maxSplits) {
        PriorityQueue<Interval> queue = new PriorityQueue<Interval>();
        if (!root.isLeaf()) {
            queue.add(root);
        }
        int numSplits = 0;
        while (!queue.isEmpty()) {
            Interval best = queue.remove();
            best.finalized = true;
            split(values, stats, best.left, minWeight);
            split(values, stats, best.right, minWeight);
            if (!best.left.isLeaf()) {
                queue.add(best.left);
            }
            if (!best.right.isLeaf()) {
                queue.add(best.right);
            }
            numSplits++;
            if (numSplits >= maxSplits) {
                break;
            }
        }
    }

    /** Copies the finalized tree into {@code func}; the last split is always +Infinity. */
    private static void writeFunction(Function1D func, Interval root) {
        List<Double> splits = new ArrayList<Double>();
        List<Double> predictions = new ArrayList<Double>();
        inorder(root, splits, predictions);
        int n = predictions.size();
        func.splits = new double[n];
        func.predictions = new double[n];
        for (int k = 0; k < n; k++) {
            func.predictions[k] = predictions.get(k);
        }
        for (int k = 0; k < n - 1; k++) {
            func.splits[k] = splits.get(k);
        }
        func.splits[n - 1] = Double.POSITIVE_INFINITY;
    }

    /** Makes {@code func} a single constant segment. */
    private static void setConstant(Function1D func, double prediction) {
        func.splits = new double[]{Double.POSITIVE_INFINITY};
        func.predictions = new double[]{prediction};
    }

    /**
     * Fills {@code values}/{@code stats} with the sorted distinct attribute values and their
     * (weight, weighted target sum). Numeric attributes are sorted and grouped. Binned and nominal
     * attributes are histogrammed by state index, skipping states with zero weight.
     */
    private static void collectStats(Instances instances, Attribute attribute, int attIndex,
            List<Double> values, List<DoublePair> stats) {
        if (attribute.getType() == Attribute.Type.NUMERIC) {
            List<Element<DoublePair>> pairs = new ArrayList<Element<DoublePair>>(instances.size());
            for (Iterator<Instance> it = instances.iterator(); it.hasNext();) {
                Instance instance = it.next();
                pairs.add(new Element<DoublePair>(new DoublePair(instance.getTarget(), instance.getWeight()),
                        instance.getValue(attIndex)));
            }
            Collections.sort(pairs);
            getStats(pairs, values, stats);
            return;
        }
        int numStates = attribute.getType() == Attribute.Type.BINNED
                ? ((BinnedAttribute) attribute).getNumBins()
                : ((NominalAttribute) attribute).getStates().length;
        DoublePair[] histogram = new DoublePair[numStates];
        for (int k = 0; k < numStates; k++) {
            histogram[k] = new DoublePair(0.0d, 0.0d);
        }
        for (Iterator<Instance> it = instances.iterator(); it.hasNext();) {
            Instance instance = it.next();
            int state = (int) instance.getValue(attIndex);
            histogram[state].v2 += instance.getTarget() * instance.getWeight();
            histogram[state].v1 += instance.getWeight();
        }
        for (int k = 0; k < numStates; k++) {
            if (histogram[k].v1 != 0.0d) {
                stats.add(histogram[k]);
                values.add((double) k);
            }
        }
    }

    public LineCutter() {
        this(false);
    }

    /** @param lineSearch whether to run {@link #lineSearch(Instances, Function1D)} after building */
    public LineCutter(boolean lineSearch) {
        this.attIndex = -1;
        this.lineSearch = lineSearch;
        this.leafLimited = true;
    }

    /**
     * Builds on the configured attribute index. Uses {@code numIntervals} if leaf limited,
     * otherwise {@code alpha}.
     */
    @Override
    public Function1D build(Instances instances) {
        return this.leafLimited ? build(instances, this.attIndex, this.numIntervals) : build(instances, this.attIndex, this.alpha);
    }

    /**
     * Builds a function on {@code attribute}. Intervals with weight at or below
     * {@code alpha * instances.size()} are not split.
     */
    public Function1D build(Instances instances, Attribute attribute, double alpha) {
        Function1D func = new Function1D();
        func.attIndex = attribute.getIndex();
        double minWeight = alpha * instances.size();
        List<Double> values = new ArrayList<Double>();
        List<DoublePair> stats = new ArrayList<DoublePair>();
        collectStats(instances, attribute, func.attIndex, values, stats);
        build(func, values, stats, minWeight);
        if (this.lineSearch) {
            lineSearch(instances, func);
        }
        return func;
    }

    /** Builds a function on {@code attribute} with at most {@code numIntervals} segments. */
    public Function1D build(Instances instances, Attribute attribute, int numIntervals) {
        Function1D func = new Function1D();
        func.attIndex = attribute.getIndex();
        List<Double> values = new ArrayList<Double>();
        List<DoublePair> stats = new ArrayList<DoublePair>();
        collectStats(instances, attribute, func.attIndex, values, stats);
        build(func, values, stats, numIntervals);
        if (this.lineSearch) {
            lineSearch(instances, func);
        }
        return func;
    }

    public Function1D build(Instances instances, int attIndex, double alpha) {
        return build(instances, instances.getAttributes().get(attIndex), alpha);
    }

    public Function1D build(Instances instances, int attIndex, int numIntervals) {
        return build(instances, instances.getAttributes().get(attIndex), numIntervals);
    }

    public boolean doLineSearch() {
        return this.lineSearch;
    }

    public int getAttributeIndex() {
        return this.attIndex;
    }

    public int getNumIntervals() {
        return this.numIntervals;
    }

    /** Sets the fraction of the instance count used as the minimum splittable weight. */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    public void setAttributeIndex(int attIndex) {
        this.attIndex = attIndex;
    }

    public void setLeafLimited(boolean leafLimited) {
        this.leafLimited = leafLimited;
    }

    public void setLineSearch(boolean lineSearch) {
        this.lineSearch = lineSearch;
    }

    public void setNumIntervals(int numIntervals) {
        this.numIntervals = numIntervals;
    }
}
