package mltk.predictor.gam.interaction;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Discretizer;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.io.InstancesReader;
import mltk.predictor.function.CHistogram;
import mltk.predictor.function.Histogram2D;
import mltk.util.Element;
import mltk.util.tuple.DoublePair;
import mltk.util.tuple.IntPair;

/**
 * FAST pairwise-interaction scoring.
 *
 * <p>The dataset's targets are replaced by residuals read from a file. Numeric attributes are
 * discretized with {@link Discretizer}. Then every attribute pair (i, j) is scored as follows:
 * for every pair of cut points, a model is fitted that predicts the weighted mean residual in
 * each of the four resulting quadrants. The pair's weight is the smallest residual sum of
 * squares (RSS) over all cut points.
 *
 * <p>Command-line options are described by {@link Options}.
 */
public class FAST {

    /**
     * Worker thread that scores its own list of attribute pairs via
     * {@link FAST#computeWeights(Instances, List)}.
     *
     * <p>Pairs must be queued with {@link #add(Element)} before the thread is started. All
     * threads share the same {@link Instances}, which {@code computeWeights} only reads from.
     */
    static class FASTThread extends Thread {
        List<Element<IntPair>> pairs = new ArrayList<>();
        Instances instances;

        FASTThread(Instances instances) {
            this.instances = instances;
        }

        /**
         * Queues one pair for this thread to score.
         *
         * @param element the attribute pair; its {@code weight} is overwritten when the thread runs
         */
        public void add(Element<IntPair> element) {
            this.pairs.add(element);
        }

        @Override
        public void run() {
            FAST.computeWeights(this.instances, this.pairs);
        }
    }

    /** Command-line options for {@link FAST#main(String[])}. */
    static class Options {

        @Argument(name = "-r", description = "attribute file path")
        String attPath = null;

        @Argument(name = "-d", description = "dataset path", required = true)
        String datasetPath = null;

        /** One residual per line, in the same order as the dataset's instances. */
        @Argument(name = "-R", description = "residual path", required = true)
        String residualPath = null;

        @Argument(name = "-o", description = "output path", required = true)
        String outputPath = null;

        @Argument(name = "-b", description = "number of bins (default: 8)")
        int maxNumBins = 8;

        @Argument(name = "-p", description = "number of threads (default: 1)")
        int numThreads = 1;

        Options() {
        }
    }

    /**
     * Quadrant statistics for every cut point of one attribute pair.
     *
     * <p>{@code resp[i][j][q]} and {@code count[i][j][q]} hold the weighted residual sum and the
     * total weight of quadrant {@code q}. Here the first attribute is cut after bin {@code i} and
     * the second after bin {@code j}. Quadrants are numbered as follows:
     * <ul>
     *   <li>0 = (x &lt;= i, y &lt;= j)</li>
     *   <li>1 = (x &lt;= i, y &gt; j)</li>
     *   <li>2 = (x &gt; i, y &lt;= j)</li>
     *   <li>3 = (x &gt; i, y &gt; j)</li>
     * </ul>
     */
    public static class Table {
        double[][][] resp;
        double[][][] count;

        Table(int n1, int n2) {
            this.resp = new double[n1][n2][4];
            this.count = new double[n1][n2][4];
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>Reads the dataset and residuals, discretizes numeric attributes, and scores every
     * attribute pair on {@code -p} threads. The pairs are then sorted with {@link Element}'s
     * natural ordering (presumably ascending weight, i.e. best fit first; confirm against
     * {@code Element.compareTo}). Finally each pair is written to the output file as
     * {@code attr1 TAB attr2 TAB weight}.
     *
     * @param args command-line arguments, see {@link Options}
     * @throws Exception if reading, scoring, or writing fails
     */
    public static void main(String[] args) throws Exception {
        Options opts = new Options();
        CmdLineParser parser = new CmdLineParser(FAST.class, opts);
        try {
            parser.parse(args);
        } catch (IllegalArgumentException e) {
            parser.printUsage();
            System.exit(1);
        }
        // A non-positive thread count would otherwise crash later with a modulo-by-zero or
        // a negative array size.
        if (opts.numThreads < 1) {
            System.err.println("Number of threads (-p) must be at least 1, got " + opts.numThreads);
            parser.printUsage();
            System.exit(1);
        }

        Instances instances = InstancesReader.read(opts.attPath, opts.datasetPath);

        System.out.println("Reading residuals...");
        readResiduals(opts.residualPath, instances);

        List<Attribute> attributes = instances.getAttributes();
        System.out.println("Discretizing attribute...");
        for (int i = 0; i < attributes.size(); i++) {
            if (attributes.get(i).getType() == Attribute.Type.NUMERIC) {
                Discretizer.discretize(instances, i, opts.maxNumBins);
            }
        }

        System.out.println("Generating all pairs of attributes...");
        List<Element<IntPair>> pairs = generateAllPairs(attributes.size());

        System.out.println("Creating threads...");
        long start = System.currentTimeMillis();
        scorePairs(instances, pairs, opts.numThreads);
        long end = System.currentTimeMillis();

        System.out.println("Sorting pairs...");
        Collections.sort(pairs);
        System.out.println("Time: " + ((end - start) / 1000.0d));

        writePairs(opts.outputPath, pairs);
    }

    /** Replaces each instance's target with the residual on the matching line of {@code path}. */
    private static void readResiduals(String path, Instances instances) throws IOException {
        try (BufferedReader reader = new BufferedReader(new FileReader(path), 65535)) {
            for (int i = 0; i < instances.size(); i++) {
                String line = reader.readLine();
                if (line == null) {
                    throw new IOException("Residual file " + path + " has only " + i
                            + " lines but the dataset has " + instances.size() + " instances");
                }
                instances.get(i).setTarget(Double.parseDouble(line));
            }
        }
    }

    /** Builds one zero-weight element per unordered attribute pair (i, j) with i &lt; j. */
    private static List<Element<IntPair>> generateAllPairs(int numAttributes) {
        List<Element<IntPair>> pairs = new ArrayList<>();
        for (int i = 0; i < numAttributes; i++) {
            for (int j = i + 1; j < numAttributes; j++) {
                pairs.add(new Element<>(new IntPair(i, j), 0.0d));
            }
        }
        return pairs;
    }

    /** Distributes pairs round-robin over {@code numThreads} workers and waits for all of them. */
    private static void scorePairs(Instances instances, List<Element<IntPair>> pairs, int numThreads)
            throws InterruptedException {
        FASTThread[] workers = new FASTThread[numThreads];
        for (int t = 0; t < workers.length; t++) {
            workers[t] = new FASTThread(instances);
        }
        for (int p = 0; p < pairs.size(); p++) {
            workers[p % workers.length].add(pairs.get(p));
        }
        for (FASTThread worker : workers) {
            worker.start();
        }
        System.out.println("Running FAST...");
        // join() also makes every worker's weight writes visible to this thread.
        for (FASTThread worker : workers) {
            worker.join();
        }
    }

    /** Writes one {@code attr1 TAB attr2 TAB weight} line per pair, failing loudly on write errors. */
    private static void writePairs(String path, List<Element<IntPair>> pairs) throws IOException {
        try (PrintWriter out = new PrintWriter(path)) {
            for (Element<IntPair> pair : pairs) {
                out.println(pair.element.v1 + "\t" + pair.element.v2 + "\t" + pair.weight);
            }
            // PrintWriter never throws on write; checkError() flushes and reports any failure.
            if (out.checkError()) {
                throw new IOException("Failed to write pairs to " + path);
            }
        }
    }

    /**
     * Scores each pair in {@code list} in place.
     *
     * <p>Each element's {@code weight} is set to the minimum RSS over all cut points (see
     * {@link #computeWeight}). {@code instances} is only read from, so several threads may call
     * this concurrently with disjoint lists.
     *
     * @param instances dataset whose targets are the residuals; every attribute referenced by
     *                  {@code list} must be nominal or binned
     * @param list      pairs to score; {@code element.v1}/{@code v2} are attribute indices
     * @throws IllegalArgumentException if a referenced attribute is neither nominal nor binned
     */
    public static void computeWeights(Instances instances, List<Element<IntPair>> list) {
        if (list.isEmpty()) {
            return;
        }
        List<Attribute> attributes = instances.getAttributes();
        boolean[] used = new boolean[attributes.size()];
        for (Element<IntPair> pair : list) {
            used[pair.element.v1] = true;
            used[pair.element.v2] = true;
        }
        // Only build histograms for attributes that some pair actually needs.
        CHistogram[] cHists = new CHistogram[attributes.size()];
        for (int i = 0; i < cHists.length; i++) {
            if (used[i]) {
                cHists[i] = new CHistogram(numBins(attributes.get(i), i));
            }
        }
        DoublePair stats = computeCHistograms(instances, used, cHists);
        for (Element<IntPair> pair : list) {
            int f1 = pair.element.v1;
            int f2 = pair.element.v2;
            Histogram2D hist2D = new Histogram2D(cHists[f1].size(), cHists[f2].size());
            computeHistogram2D(instances, f1, f2, hist2D);
            computeWeight(pair, cHists, hist2D, stats.v1, stats.v2);
        }
    }

    /** Returns the number of discrete values of a nominal or binned attribute. */
    private static int numBins(Attribute attribute, int index) {
        switch (attribute.getType()) {
            case NOMINAL:
                return ((NominalAttribute) attribute).getCardinality();
            case BINNED:
                return ((BinnedAttribute) attribute).getNumBins();
            default:
                throw new IllegalArgumentException("Attribute " + index + " has type "
                        + attribute.getType() + "; FAST requires nominal or binned attributes"
                        + " (discretize numeric attributes first)");
        }
    }

    /**
     * Fills the cumulative 1-D histograms for every attribute flagged in {@code zArr}.
     *
     * <p>After this call, {@code sum[b]} and {@code count[b]} of each used histogram hold the
     * weighted residual sum and the total weight over all bins {@code <= b}.
     *
     * @return {@code v1} = sum of weight * target^2, {@code v2} = sum of weights
     */
    protected static DoublePair computeCHistograms(Instances instances, boolean[] zArr, CHistogram[] cHistogramArr) {
        int numAttributes = instances.getAttributes().size();
        double sumSquaredTarget = 0.0d;
        double totalWeight = 0.0d;
        Iterator<Instance> it = instances.iterator();
        while (it.hasNext()) {
            Instance inst = it.next();
            double target = inst.getTarget();
            for (int a = 0; a < numAttributes; a++) {
                if (zArr[a]) {
                    int bin = (int) inst.getValue(a);
                    cHistogramArr[a].sum[bin] += target * inst.getWeight();
                    cHistogramArr[a].count[bin] += inst.getWeight();
                }
            }
            sumSquaredTarget += target * target * inst.getWeight();
            totalWeight += inst.getWeight();
        }
        // Convert the per-bin totals into prefix sums.
        for (int a = 0; a < cHistogramArr.length; a++) {
            if (zArr[a]) {
                CHistogram hist = cHistogramArr[a];
                for (int b = 1; b < hist.size(); b++) {
                    hist.sum[b] += hist.sum[b - 1];
                    hist.count[b] += hist.count[b - 1];
                }
            }
        }
        return new DoublePair(sumSquaredTarget, totalWeight);
    }

    /**
     * Accumulates, per (bin of attribute {@code i}, bin of attribute {@code i2}) cell, the
     * weighted target sum into {@code resp} and the total weight into {@code count}.
     */
    protected static void computeHistogram2D(Instances instances, int i, int i2, Histogram2D histogram2D) {
        Iterator<Instance> it = instances.iterator();
        while (it.hasNext()) {
            Instance inst = it.next();
            int x = (int) inst.getValue(i);
            int y = (int) inst.getValue(i2);
            histogram2D.resp[x][y] += inst.getTarget() * inst.getWeight();
            histogram2D.count[x][y] += inst.getWeight();
        }
    }

    /**
     * Builds the quadrant table for one attribute pair.
     *
     * <p>Quadrant 0 is computed as a 2-D prefix sum of {@code histogram2D}. The other three
     * quadrants are derived from the cumulative 1-D histograms by {@link #fillTable}.
     */
    protected static void computeTable(Histogram2D histogram2D, CHistogram cHistogram, CHistogram cHistogram2, Table table) {
        for (int i = 0; i < histogram2D.resp.length; i++) {
            double rowResp = 0.0d;
            double rowCount = 0.0d;
            for (int j = 0; j < histogram2D.resp[i].length; j++) {
                rowResp += histogram2D.resp[i][j];
                rowCount += histogram2D.count[i][j];
                // Each prefix-sum cell = cell directly above (previous bin of the first
                // attribute) + running sum of the current row.
                table.resp[i][j][0] = (i == 0) ? rowResp : table.resp[i - 1][j][0] + rowResp;
                table.count[i][j][0] = (i == 0) ? rowCount : table.count[i - 1][j][0] + rowCount;
                fillTable(table, i, j, cHistogram, cHistogram2);
            }
        }
    }

    /**
     * Derives quadrants 1-3 at cut point ({@code i}, {@code i2}) from quadrant 0 and the
     * cumulative histograms of the first ({@code cHistogram}) and second ({@code cHistogram2})
     * attribute.
     */
    protected static void fillTable(Table table, int i, int i2, CHistogram cHistogram, CHistogram cHistogram2) {
        double[] count = table.count[i][i2];
        double[] resp = table.resp[i][i2];
        int last1 = cHistogram.size() - 1;
        resp[1] = cHistogram.sum[i] - resp[0];
        resp[2] = cHistogram2.sum[i2] - resp[0];
        resp[3] = (cHistogram.sum[last1] - cHistogram.sum[i]) - resp[2];
        count[1] = cHistogram.count[i] - count[0];
        count[2] = cHistogram2.count[i2] - count[0];
        count[3] = (cHistogram.count[last1] - cHistogram.count[i]) - count[2];
    }

    /**
     * Sets {@code element.weight} to the minimum RSS over every interior cut point of the pair.
     *
     * <p>Cutting after the last bin would leave two quadrants empty, so each axis has
     * {@code size - 1} cut points. If either attribute has a single bin, no cut exists and the
     * weight stays {@link Double#POSITIVE_INFINITY}.
     *
     * @param d  sum of weight * target^2 over the dataset
     * @param d2 total weight of the dataset (currently unused)
     */
    protected static void computeWeight(Element<IntPair> element, CHistogram[] cHistogramArr, Histogram2D histogram2D, double d, double d2) {
        int f1 = element.element.v1;
        int f2 = element.element.v2;
        int n1 = cHistogramArr[f1].size();
        int n2 = cHistogramArr[f2].size();
        Table table = new Table(n1, n2);
        computeTable(histogram2D, cHistogramArr[f1], cHistogramArr[f2], table);
        double bestRss = Double.POSITIVE_INFINITY;
        double[] predictor = new double[4];
        for (int i = 0; i < n1 - 1; i++) {
            for (int j = 0; j < n2 - 1; j++) {
                getPredictor(table, i, j, predictor);
                double rss = getRSS(table, i, j, d, predictor);
                if (rss < bestRss) {
                    bestRss = rss;
                }
            }
        }
        element.weight = bestRss;
    }

    /**
     * Writes each quadrant's weighted mean residual at cut ({@code i}, {@code i2}) into
     * {@code dArr}. Empty quadrants predict 0.
     */
    protected static void getPredictor(Table table, int i, int i2, double[] dArr) {
        double[] count = table.count[i][i2];
        double[] resp = table.resp[i][i2];
        for (int q = 0; q < dArr.length; q++) {
            dArr[q] = count[q] == 0.0d ? 0.0d : resp[q] / count[q];
        }
    }

    /**
     * Computes the RSS of the piecewise-constant predictor {@code dArr} at cut ({@code i},
     * {@code i2}).
     *
     * <p>The RSS expands as {@code sum(w*y^2) - 2*sum_q p_q*R_q + sum_q p_q^2*C_q}, so no pass
     * over the instances is needed.
     *
     * @param d sum of weight * target^2 over the dataset
     */
    protected static double getRSS(Table table, int i, int i2, double d, double[] dArr) {
        double[] count = table.count[i][i2];
        double[] resp = table.resp[i][i2];
        double predSquared = 0.0d;
        for (int q = 0; q < dArr.length; q++) {
            predSquared += dArr[q] * dArr[q] * count[q];
        }
        double cross = 0.0d;
        for (int q = 0; q < dArr.length; q++) {
            cross += dArr[q] * resp[q];
        }
        return (d + predSquared) - (2.0d * cross);
    }
}
