package mltk.predictor.function;

import java.util.Iterator;
import java.util.List;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.predictor.Learner;
import mltk.util.tuple.IntDoublePair;
import mltk.util.tuple.Pair;

/* loaded from: input_file:mltk/predictor/function/SquareCutter.class */
/**
 * Fits a four-leaf, piecewise-constant function of two discrete attributes.
 *
 * <p>The model makes one "first" cut on one attribute. Then, independently on each side of
 * that cut, it makes one cut on the other attribute. Both orientations are tried (first cut on
 * attribute 1 or on attribute 2), as is every possible first cut. The configuration with the
 * lowest weighted squared error wins; ties keep the attribute-1-first solution.
 *
 * <p>Leaf values are weighted target means. When line search is enabled, they are replaced by
 * the step computed by {@link #lineSearch}.
 *
 * <p>Both attributes must be binned or nominal. Instance values are truncated to {@code int}
 * and used directly as bin indices.
 */
public class SquareCutter extends Learner {

    /** Position (in {@code Instances.getAttributes()}) of the first attribute; rows of the table. */
    private int attIndex1;

    /** Position (in {@code Instances.getAttributes()}) of the second attribute; columns of the table. */
    private int attIndex2;

    /** Whether leaf values are refit with {@link #lineSearch} after the cuts are chosen. */
    private boolean lineSearch;

    /**
     * Quadrant statistics for every candidate cut pair.
     *
     * <p>For a cut {@code (i, j)}, {@code resp[i][j][k]} holds the weighted target sum of
     * quadrant {@code k}, and {@code count[i][j][k]} holds its total weight:
     * <ul>
     * <li>0: attribute 1 {@code <= i}, attribute 2 {@code <= j}</li>
     * <li>1: attribute 1 {@code <= i}, attribute 2 {@code > j}</li>
     * <li>2: attribute 1 {@code > i}, attribute 2 {@code <= j}</li>
     * <li>3: attribute 1 {@code > i}, attribute 2 {@code > j}</li>
     * </ul>
     * The table is filled by {@code SquareCutter.computeTable} and {@code SquareCutter.fillTable}.
     */
    public static class Table {
        double[][][] resp;
        double[][][] count;

        /**
         * Allocates a zero-filled table.
         *
         * @param size1 number of values of attribute 1
         * @param size2 number of values of attribute 2
         */
        Table(int size1, int size2) {
            this.resp = new double[size1][size2][4];
            this.count = new double[size1][size2][4];
        }
    }

    /** Creates a cutter without line search. */
    public SquareCutter() {
    }

    /**
     * Creates a cutter.
     *
     * @param lineSearch whether to refit leaf values with {@link #lineSearch} after cutting
     */
    public SquareCutter(boolean lineSearch) {
        this.lineSearch = lineSearch;
    }

    /**
     * Sets the pair of attributes the function is built on.
     *
     * @param attIndex1 position of the first attribute in {@code Instances.getAttributes()}
     * @param attIndex2 position of the second attribute in {@code Instances.getAttributes()}
     */
    public void setAttIndices(int attIndex1, int attIndex2) {
        this.attIndex1 = attIndex1;
        this.attIndex2 = attIndex2;
    }

    /**
     * Builds the best four-leaf function for the configured attribute pair.
     *
     * @param instances training data; each instance's target and weight are used
     * @return the fitted function, or a single-cell function with value 0 when either
     *     attribute has exactly one value
     * @throws IllegalArgumentException if either attribute is neither binned nor nominal
     */
    @Override // mltk.predictor.Learner
    public Function2D build(Instances instances) {
        List<Attribute> attributes = instances.getAttributes();
        Attribute f1 = attributes.get(this.attIndex1);
        Attribute f2 = attributes.get(this.attIndex2);
        // Each attribute's size is derived from its own type.
        int size1 = getNumValues(f1);
        int size2 = getNumValues(f2);

        if (size1 == 1 || size2 == 1) {
            // A single-valued axis admits no cut.
            return new Function2D(f1.getIndex(), f2.getIndex(), new double[] { Double.POSITIVE_INFINITY },
                    new double[] { Double.POSITIVE_INFINITY }, new double[1][1]);
        }

        Histogram2D hist2d = buildHistogram(instances, f1, f2, size1, size2);
        Pair<CHistogram, CHistogram> cHist = hist2d.computeCHistogram();
        Table table = new Table(size1, size2);
        computeTable(hist2d, cHist.v1, cHist.v2, table);

        double bestRSS = Double.POSITIVE_INFINITY;
        int[] cuts = new int[2];
        double[] predictions = new double[4];

        // Orientation 1: first cut on attribute 1, then one attribute-2 cut per side.
        int bestCut1 = -1;
        int[] bestCuts1 = new int[2];
        for (int i = 0; i < size1 - 1; i++) {
            findCuts(table, i, cuts);
            getPredictor(table, i, cuts, predictions);
            double rss = getRSS(table, i, cuts, predictions);
            if (rss < bestRSS) {
                bestRSS = rss;
                bestCut1 = i;
                bestCuts1[0] = cuts[0];
                bestCuts1[1] = cuts[1];
            }
        }

        // Orientation 2: first cut on attribute 2, then one attribute-1 cut per side.
        // It replaces the orientation-1 winner only when strictly better.
        boolean attribute2First = false;
        int bestCut2 = -1;
        int[] bestCuts2 = new int[2];
        for (int j = 0; j < size2 - 1; j++) {
            findCuts(table, cuts, j);
            getPredictor(table, cuts, j, predictions);
            double rss = getRSS(table, cuts, j, predictions);
            if (rss < bestRSS) {
                bestRSS = rss;
                bestCut2 = j;
                bestCuts2[0] = cuts[0];
                bestCuts2[1] = cuts[1];
                attribute2First = true;
            }
        }

        // The loops leave the *last* candidate's predictions in place, so recompute for the winner.
        double[] leaves = new double[4];
        if (attribute2First) {
            getPredictor(table, bestCuts2, bestCut2, leaves);
            if (this.lineSearch) {
                lineSearch(instances, f2.getIndex(), f1.getIndex(), bestCut2, bestCuts2[0], bestCuts2[1], leaves);
            }
            return getFunction2D(f1.getIndex(), f2.getIndex(), bestCuts2, bestCut2, leaves);
        }
        getPredictor(table, bestCut1, bestCuts1, leaves);
        if (this.lineSearch) {
            lineSearch(instances, f1.getIndex(), f2.getIndex(), bestCut1, bestCuts1[0], bestCuts1[1], leaves);
        }
        return getFunction2D(f1.getIndex(), f2.getIndex(), bestCut1, bestCuts1, leaves);
    }

    /**
     * Returns the number of discrete values of a binned or nominal attribute.
     *
     * @throws IllegalArgumentException for any other attribute type
     */
    private static int getNumValues(Attribute attribute) {
        if (attribute.getType() == Attribute.Type.BINNED) {
            return ((BinnedAttribute) attribute).getNumBins();
        }
        if (attribute.getType() == Attribute.Type.NOMINAL) {
            return ((NominalAttribute) attribute).getCardinality();
        }
        throw new IllegalArgumentException("SquareCutter requires binned or nominal attributes, got "
                + attribute.getType() + " for attribute index " + attribute.getIndex());
    }

    /** Accumulates the weighted target sum and total weight for every (attribute 1, attribute 2) cell. */
    private static Histogram2D buildHistogram(Instances instances, Attribute f1, Attribute f2, int size1,
            int size2) {
        Histogram2D hist2d = new Histogram2D(size1, size2);
        Iterator<Instance> iter = instances.iterator();
        while (iter.hasNext()) {
            Instance instance = iter.next();
            int v1 = (int) instance.getValue(f1);
            int v2 = (int) instance.getValue(f2);
            hist2d.resp[v1][v2] += instance.getTarget() * instance.getWeight();
            hist2d.count[v1][v2] += instance.getWeight();
        }
        return hist2d;
    }

    /**
     * Derives quadrants 1-3 of cell {@code (i, j)} from quadrant 0 and the cumulative marginals.
     *
     * <p>Assumes {@code cHist1}/{@code cHist2} are cumulative marginals along attribute 1 and
     * attribute 2 respectively, and that quadrant 0 of {@code (i, j)} is already filled.
     * Quadrant numbering is documented on {@link Table}.
     */
    protected static void fillTable(Table table, int i, int j, CHistogram cHist1, CHistogram cHist2) {
        int last = cHist1.size() - 1;
        double[] resp = table.resp[i][j];
        double[] count = table.count[i][j];
        resp[1] = cHist1.sum[i] - resp[0];
        resp[2] = cHist2.sum[j] - resp[0];
        resp[3] = (cHist1.sum[last] - cHist1.sum[i]) - resp[2];
        count[1] = cHist1.count[i] - count[0];
        count[2] = cHist2.count[j] - count[0];
        count[3] = (cHist1.count[last] - cHist1.count[i]) - count[2];
    }

    /**
     * Fills every cell of {@code table}.
     *
     * <p>Quadrant 0 is a 2D prefix sum of {@code hist2d}; the other quadrants are derived by
     * {@link #fillTable}.
     */
    protected static void computeTable(Histogram2D hist2d, CHistogram cHist1, CHistogram cHist2, Table table) {
        for (int i = 0; i < hist2d.resp.length; i++) {
            // Running sums along the current row, added to the cumulative value from the row above.
            double rowResp = 0.0d;
            double rowCount = 0.0d;
            for (int j = 0; j < hist2d.resp[i].length; j++) {
                rowResp += hist2d.resp[i][j];
                rowCount += hist2d.count[i][j];
                table.resp[i][j][0] = i == 0 ? rowResp : table.resp[i - 1][j][0] + rowResp;
                table.count[i][j][0] = i == 0 ? rowCount : table.count[i - 1][j][0] + rowCount;
                fillTable(table, i, j, cHist1, cHist2);
            }
        }
    }

    /** Weighted mean {@code sum / weight}, or 0 for an empty region. */
    private static double mean(double sum, double weight) {
        return weight == 0.0d ? 0.0d : sum / weight;
    }

    /**
     * Least-squares score of splitting a region into sides A and B (lower is better).
     *
     * <p>Computes {@code -(sumA^2 / countA) - (sumB^2 / countB)}, i.e. the weighted squared error
     * up to a constant. An empty side contributes 0 instead of producing {@code 0/0 = NaN}.
     */
    private static double evalSplit(double sumA, double countA, double sumB, double countB) {
        double eval = 0.0d;
        if (countA != 0.0d) {
            eval -= sumA * sumA / countA;
        }
        if (countB != 0.0d) {
            eval -= sumB * sumB / countB;
        }
        return eval;
    }

    /**
     * Computes the four quadrant means of cell {@code (i, j)} into {@code predictions}.
     * Empty quadrants get 0.
     */
    protected static void getPredictor(Table table, int i, int j, double[] predictions) {
        for (int k = 0; k < predictions.length; k++) {
            predictions[k] = mean(table.resp[i][j][k], table.count[i][j][k]);
        }
    }

    /**
     * Returns the weighted squared error of the cell-{@code (i, j)} quadrant predictions.
     * The constant term (the weighted sum of squared targets) is omitted.
     */
    protected static double getRSS(Table table, int i, int j, double[] predictions) {
        double[] resp = table.resp[i][j];
        double[] count = table.count[i][j];
        double squares = 0.0d;
        for (int k = 0; k < predictions.length; k++) {
            squares += predictions[k] * predictions[k] * count[k];
        }
        double cross = 0.0d;
        for (int k = 0; k < predictions.length; k++) {
            cross += predictions[k] * resp[k];
        }
        return squares - 2.0d * cross;
    }

    /**
     * Chooses the attribute-2 cut on each side of a first cut on attribute 1 at {@code i}.
     *
     * <p>{@code cuts[0]} is the best cut for the side where attribute 1 {@code <= i}
     * (quadrants 0/1). {@code cuts[1]} is the best cut for the side where attribute 1 {@code > i}
     * (quadrants 2/3). Ties keep the smallest index.
     */
    protected static void findCuts(Table table, int i, int[] cuts) {
        double[][] resp = table.resp[i];
        double[][] count = table.count[i];
        cuts[0] = 0;
        cuts[1] = 0;
        double bestLower = Double.POSITIVE_INFINITY;
        double bestUpper = Double.POSITIVE_INFINITY;
        for (int j = 0; j < resp.length - 1; j++) {
            double evalLower = evalSplit(resp[j][0], count[j][0], resp[j][1], count[j][1]);
            if (evalLower < bestLower) {
                bestLower = evalLower;
                cuts[0] = j;
            }
            double evalUpper = evalSplit(resp[j][2], count[j][2], resp[j][3], count[j][3]);
            if (evalUpper < bestUpper) {
                bestUpper = evalUpper;
                cuts[1] = j;
            }
        }
    }

    /**
     * Chooses the attribute-1 cut on each side of a first cut on attribute 2 at {@code j}.
     *
     * <p>{@code cuts[0]} is the best cut for the side where attribute 2 {@code <= j}
     * (quadrants 0/2). {@code cuts[1]} is the best cut for the side where attribute 2 {@code > j}
     * (quadrants 1/3). Ties keep the smallest index.
     */
    protected static void findCuts(Table table, int[] cuts, int j) {
        cuts[0] = 0;
        cuts[1] = 0;
        double bestLower = Double.POSITIVE_INFINITY;
        double bestUpper = Double.POSITIVE_INFINITY;
        for (int i = 0; i < table.resp.length - 1; i++) {
            double[] resp = table.resp[i][j];
            double[] count = table.count[i][j];
            double evalLower = evalSplit(resp[0], count[0], resp[2], count[2]);
            if (evalLower < bestLower) {
                bestLower = evalLower;
                cuts[0] = i;
            }
            double evalUpper = evalSplit(resp[1], count[1], resp[3], count[3]);
            if (evalUpper < bestUpper) {
                bestUpper = evalUpper;
                cuts[1] = i;
            }
        }
    }

    /**
     * Finds the least-squares cut of a cumulative 1D histogram.
     *
     * <p>On return, {@code cut.v1} is the cut index (values {@code <= v1} on the left) and
     * {@code cut.v2} is its score {@code -(S_L^2/N_L) - (S_R^2/N_R)}; lower is better.
     * When there is no candidate (size {@code <= 1}), {@code cut.v2} is {@code +Infinity} and
     * {@code cut.v1} is left untouched.
     */
    protected static void findCut(CHistogram cHist, IntDoublePair cut) {
        cut.v2 = Double.POSITIVE_INFINITY;
        int last = cHist.size() - 1;
        if (last <= 0) {
            return;
        }
        double totalSum = cHist.sum[last];
        double totalCount = cHist.count[last];
        for (int i = 0; i < last; i++) {
            double leftSum = cHist.sum[i];
            double leftCount = cHist.count[i];
            // Divide by the weights: the squared-error reduction is S^2/N, not S^2*N.
            double eval = evalSplit(leftSum, leftCount, totalSum - leftSum, totalCount - leftCount);
            if (eval < cut.v2) {
                cut.v2 = eval;
                cut.v1 = i;
            }
        }
    }

    /**
     * Computes leaf means for a first cut on attribute 1 at {@code i}.
     *
     * <p>The leaf order is:
     * <ol start="0">
     * <li>attribute 1 {@code <= i}, attribute 2 {@code <= cuts[0]}</li>
     * <li>attribute 1 {@code <= i}, attribute 2 {@code > cuts[0]}</li>
     * <li>attribute 1 {@code > i}, attribute 2 {@code <= cuts[1]}</li>
     * <li>attribute 1 {@code > i}, attribute 2 {@code > cuts[1]}</li>
     * </ol>
     * Empty leaves get 0.
     */
    protected static void getPredictor(Table table, int i, int[] cuts, double[] predictions) {
        double[] lowerResp = table.resp[i][cuts[0]];
        double[] lowerCount = table.count[i][cuts[0]];
        double[] upperResp = table.resp[i][cuts[1]];
        double[] upperCount = table.count[i][cuts[1]];
        predictions[0] = mean(lowerResp[0], lowerCount[0]);
        predictions[1] = mean(lowerResp[1], lowerCount[1]);
        predictions[2] = mean(upperResp[2], upperCount[2]);
        predictions[3] = mean(upperResp[3], upperCount[3]);
    }

    /**
     * Computes leaf means for a first cut on attribute 2 at {@code j}.
     *
     * <p>The leaf order is:
     * <ol start="0">
     * <li>attribute 2 {@code <= j}, attribute 1 {@code <= cuts[0]}</li>
     * <li>attribute 2 {@code <= j}, attribute 1 {@code > cuts[0]}</li>
     * <li>attribute 2 {@code > j}, attribute 1 {@code <= cuts[1]}</li>
     * <li>attribute 2 {@code > j}, attribute 1 {@code > cuts[1]}</li>
     * </ol>
     * Empty leaves get 0.
     */
    protected static void getPredictor(Table table, int[] cuts, int j, double[] predictions) {
        double[] lowerResp = table.resp[cuts[0]][j];
        double[] lowerCount = table.count[cuts[0]][j];
        double[] upperResp = table.resp[cuts[1]][j];
        double[] upperCount = table.count[cuts[1]][j];
        predictions[0] = mean(lowerResp[0], lowerCount[0]);
        predictions[1] = mean(lowerResp[2], lowerCount[2]);
        predictions[2] = mean(upperResp[1], upperCount[1]);
        predictions[3] = mean(upperResp[3], upperCount[3]);
    }

    /**
     * Weighted squared error, minus its constant term, of the attribute-1-first leaves.
     * The leaves are laid out as in {@link #getPredictor(Table, int, int[], double[])}.
     */
    protected static double getRSS(Table table, int i, int[] cuts, double[] predictions) {
        double[] lowerResp = table.resp[i][cuts[0]];
        double[] lowerCount = table.count[i][cuts[0]];
        double[] upperResp = table.resp[i][cuts[1]];
        double[] upperCount = table.count[i][cuts[1]];
        // Summation order is kept fixed so candidate comparisons are reproducible.
        double squares = 0.0d;
        squares += predictions[0] * predictions[0] * lowerCount[0];
        squares += predictions[1] * predictions[1] * lowerCount[1];
        squares += predictions[2] * predictions[2] * upperCount[2];
        squares += predictions[3] * predictions[3] * upperCount[3];
        double cross = 0.0d;
        cross += predictions[0] * lowerResp[0];
        cross += predictions[1] * lowerResp[1];
        cross += predictions[2] * upperResp[2];
        cross += predictions[3] * upperResp[3];
        return squares - 2.0d * cross;
    }

    /**
     * Weighted squared error, minus its constant term, of the attribute-2-first leaves.
     * The leaves are laid out as in {@link #getPredictor(Table, int[], int, double[])}.
     */
    protected static double getRSS(Table table, int[] cuts, int j, double[] predictions) {
        double[] lowerResp = table.resp[cuts[0]][j];
        double[] lowerCount = table.count[cuts[0]][j];
        double[] upperResp = table.resp[cuts[1]][j];
        double[] upperCount = table.count[cuts[1]][j];
        double squares = 0.0d;
        squares += predictions[0] * predictions[0] * lowerCount[0];
        squares += predictions[1] * predictions[1] * lowerCount[2];
        squares += predictions[2] * predictions[2] * upperCount[1];
        squares += predictions[3] * predictions[3] * upperCount[3];
        double cross = 0.0d;
        cross += predictions[0] * lowerResp[0];
        cross += predictions[1] * lowerResp[2];
        cross += predictions[2] * upperResp[1];
        cross += predictions[3] * upperResp[3];
        return squares - 2.0d * cross;
    }

    /**
     * Converts an attribute-1-first solution into a {@link Function2D}.
     *
     * <p>Attribute 1 gets splits {@code {cut, +Inf}}. Attribute 2 gets the sorted distinct
     * values of {@code cuts} followed by {@code +Inf}. The value grid is indexed
     * {@code [attribute-1 cell][attribute-2 cell]}, with each cell taking the matching leaf of
     * {@code predictions}. The leaf order is that of
     * {@link #getPredictor(Table, int, int[], double[])}.
     */
    protected static Function2D getFunction2D(int attIndex1, int attIndex2, int cut, int[] cuts,
            double[] predictions) {
        double[] splits1 = { cut, Double.POSITIVE_INFINITY };
        double[] splits2;
        double[][] values;
        if (cuts[0] < cuts[1]) {
            splits2 = new double[] { cuts[0], cuts[1], Double.POSITIVE_INFINITY };
            values = new double[][] {
                    { predictions[0], predictions[1], predictions[1] },
                    { predictions[2], predictions[2], predictions[3] } };
        } else if (cuts[0] > cuts[1]) {
            splits2 = new double[] { cuts[1], cuts[0], Double.POSITIVE_INFINITY };
            values = new double[][] {
                    { predictions[0], predictions[0], predictions[1] },
                    { predictions[2], predictions[3], predictions[3] } };
        } else {
            splits2 = new double[] { cuts[0], Double.POSITIVE_INFINITY };
            values = new double[][] {
                    { predictions[0], predictions[1] },
                    { predictions[2], predictions[3] } };
        }
        return new Function2D(attIndex1, attIndex2, splits1, splits2, values);
    }

    /**
     * Converts an attribute-2-first solution into a {@link Function2D}.
     *
     * <p>Attribute 2 gets splits {@code {cut, +Inf}}. Attribute 1 gets the sorted distinct
     * values of {@code cuts} followed by {@code +Inf}. The value grid is indexed
     * {@code [attribute-1 cell][attribute-2 cell]}. The leaf order is that of
     * {@link #getPredictor(Table, int[], int, double[])}.
     */
    protected static Function2D getFunction2D(int attIndex1, int attIndex2, int[] cuts, int cut,
            double[] predictions) {
        double[] splits2 = { cut, Double.POSITIVE_INFINITY };
        double[] splits1;
        double[][] values;
        if (cuts[0] < cuts[1]) {
            splits1 = new double[] { cuts[0], cuts[1], Double.POSITIVE_INFINITY };
            values = new double[][] {
                    { predictions[0], predictions[2] },
                    { predictions[1], predictions[2] },
                    { predictions[1], predictions[3] } };
        } else if (cuts[0] > cuts[1]) {
            splits1 = new double[] { cuts[1], cuts[0], Double.POSITIVE_INFINITY };
            values = new double[][] {
                    { predictions[0], predictions[2] },
                    { predictions[0], predictions[3] },
                    { predictions[1], predictions[3] } };
        } else {
            splits1 = new double[] { cuts[0], Double.POSITIVE_INFINITY };
            values = new double[][] {
                    { predictions[0], predictions[2] },
                    { predictions[1], predictions[3] } };
        }
        return new Function2D(attIndex1, attIndex2, splits1, splits2, values);
    }

    /**
     * Replaces the four leaf values with a per-leaf step {@code sum(w*t) / sum(w*|t|*(1-|t|))}.
     *
     * <p>Leaves whose denominator has magnitude below 1e-8 get 0. An instance is routed by
     * comparing its {@code attIndex1} value with {@code cut}, then its {@code attIndex2} value
     * with {@code cut1} (left side) or {@code cut2} (right side). This gives the same leaf order
     * as the {@code getPredictor} overloads.
     *
     * <p>NOTE(review): this matches a Newton step for logistic loss when targets are
     * pseudo-residuals {@code y - p}. That use is presumed, not enforced — confirm against
     * callers.
     */
    protected static void lineSearch(Instances instances, int attIndex1, int attIndex2, int cut, int cut1,
            int cut2, double[] predictions) {
        double[] numerator = new double[4];
        double[] denominator = new double[4];
        Iterator<Instance> iter = instances.iterator();
        while (iter.hasNext()) {
            Instance instance = iter.next();
            int v1 = (int) instance.getValue(attIndex1);
            int v2 = (int) instance.getValue(attIndex2);
            int leaf;
            if (v1 <= cut) {
                leaf = v2 <= cut1 ? 0 : 1;
            } else {
                leaf = v2 <= cut2 ? 2 : 3;
            }
            double target = instance.getTarget();
            double absTarget = Math.abs(target);
            double weight = instance.getWeight();
            numerator[leaf] += target * weight;
            denominator[leaf] += absTarget * (1.0d - absTarget) * weight;
        }
        for (int k = 0; k < numerator.length; k++) {
            predictions[k] = Math.abs(denominator[k]) < 1.0E-8d ? 0.0d : numerator[k] / denominator[k];
        }
    }
}
