package mltk.predictor.gam.tool;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.tools.Diagnostic;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.io.InstancesReader;
import mltk.predictor.Regressor;
import mltk.predictor.gam.GAM;
import mltk.predictor.io.PredictorReader;
import mltk.util.Element;
import mltk.util.StatUtils;

/**
 * Command-line tool that measures how much each term of a GAM contributes across a dataset.
 *
 * <p>For every distinct term, the outputs of all regressors attached to that term are summed
 * per instance. The variance of those sums is reported as the term's weight. Each output line
 * has the form {@code [i, j, ...]: weight}.
 */
public class Diagnostics {

    /** Command-line options, filled in by {@link CmdLineParser} from the {@link Argument} annotations. */
    static class Options {

        @Argument(name = "-r", description = "attribute file path", required = true)
        String attPath = null;

        @Argument(name = "-d", description = "dataset path", required = true)
        String datasetPath = null;

        @Argument(name = "-i", description = "input model path", required = true)
        String inputModelPath = null;

        @Argument(name = "-o", description = "output path", required = true)
        String outputPath = null;
    }

    /** A term's attribute indices together with every regressor attached to that term. */
    private static final class TermGroup {
        final int[] term;
        final List<Regressor> regressors = new ArrayList<>();

        TermGroup(int[] term) {
            this.term = term;
        }
    }

    /** Converts a term's indices into a value-comparable map key (int[] only has identity equality). */
    private static List<Integer> toKey(int[] term) {
        List<Integer> key = new ArrayList<>(term.length);
        for (int index : term) {
            key.add(index);
        }
        return key;
    }

    /**
     * Computes, for each distinct term of a GAM, the variance of that term's contribution over
     * a dataset.
     *
     * <p>Terms are compared by their index values, not by array identity. Regressors whose
     * terms have the same indices are therefore summed together. Results are listed in the
     * order in which each term first appears in {@code gam.getTerms()}.
     *
     * @param gam the model to diagnose. The i-th entry of {@code getRegressors()} is paired
     *     with the i-th entry of {@code getTerms()}.
     * @param instances the dataset on which term contributions are evaluated
     * @return one element per distinct term. {@code element} is the first index array seen for
     *     that term, and {@code weight} is the variance, over {@code instances}, of the summed
     *     output of the term's regressors.
     */
    public static List<Element<int[]>> diagnose(GAM gam, Instances instances) {
        List<int[]> terms = gam.getTerms();
        List<Regressor> regressors = gam.getRegressors();

        // Key by index values: a HashMap<int[], ...> would never merge equal-content arrays.
        Map<List<Integer>, TermGroup> groups = new LinkedHashMap<>();
        for (int i = 0; i < terms.size(); i++) {
            int[] term = terms.get(i);
            List<Integer> key = toKey(term);
            TermGroup group = groups.get(key);
            if (group == null) {
                group = new TermGroup(term);
                groups.put(key, group);
            }
            group.regressors.add(regressors.get(i));
        }

        List<Element<int[]>> result = new ArrayList<>(groups.size());
        // Reused across groups; every slot is overwritten before the variance is taken.
        double[] contributions = new double[instances.size()];
        for (TermGroup group : groups.values()) {
            for (int i = 0; i < contributions.length; i++) {
                Instance instance = instances.get(i);
                double sum = 0.0;
                for (Regressor regressor : group.regressors) {
                    sum += regressor.regress(instance);
                }
                contributions[i] = sum;
            }
            result.add(new Element<int[]>(group.term, StatUtils.variance(contributions)));
        }
        return result;
    }

    /**
     * Entry point: reads a GAM and a dataset, diagnoses the terms and writes one line per term.
     *
     * <p>Lines are written after {@code Collections.sort} followed by {@code Collections.reverse},
     * i.e. in reverse natural order of {@link Element}. This is presumably descending weight;
     * confirm against {@code Element.compareTo}.
     *
     * @param args command-line arguments; see {@link Options}
     * @throws IOException if the output file cannot be written
     * @throws Exception propagated from the model/dataset readers
     */
    public static void main(String[] args) throws Exception {
        Options options = new Options();
        // Pass this tool's own class (presumably used by CmdLineParser when printing usage).
        CmdLineParser parser = new CmdLineParser(Diagnostics.class, options);
        try {
            parser.parse(args);
        } catch (IllegalArgumentException e) {
            parser.printUsage();
            System.exit(1);
        }

        GAM gam = (GAM) PredictorReader.read(options.inputModelPath);
        Instances instances = InstancesReader.read(options.attPath, options.datasetPath);
        List<Element<int[]>> diagnosis = diagnose(gam, instances);
        Collections.sort(diagnosis);
        Collections.reverse(diagnosis);

        try (PrintWriter out = new PrintWriter(options.outputPath)) {
            for (Element<int[]> entry : diagnosis) {
                out.println(Arrays.toString(entry.element) + ": " + entry.weight);
            }
            // PrintWriter swallows IOExceptions; surface them instead of leaving a truncated file.
            if (out.checkError()) {
                throw new IOException("Failed to write diagnostics to " + options.outputPath);
            }
        }
    }
}
