package mltk.predictor.gam.tool;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instance;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.io.InstancesReader;
import mltk.predictor.Regressor;
import mltk.predictor.function.CubicSpline;
import mltk.predictor.gam.GAM;
import mltk.predictor.io.PredictorReader;

/* loaded from: input_file:mltk/predictor/gam/tool/Visualizer.class */
/**
 * Command-line tool that writes one gnuplot script ({@code .plt}) per term of a {@link GAM}.
 *
 * <p>Terms over a single attribute produce a script named after that attribute; terms over
 * two attributes produce a script named {@code <first>_<second>.plt}. Terms over any other
 * number of attributes are ignored.
 */
public class Visualizer {

    /**
     * Command-line options handed to {@link CmdLineParser} in {@link #main(String[])}.
     * The fields are intentionally mutable; presumably the parser assigns them based on
     * the {@link Argument} annotations.
     */
    static class Options {

        @Argument(name = "-r", description = "attribute file path", required = true)
        String attPath = null;

        @Argument(name = "-d", description = "dataset path", required = true)
        String datasetPath = null;

        @Argument(name = "-i", description = "input model path", required = true)
        String inputModelPath = null;

        @Argument(name = "-o", description = "output directory path", required = true)
        String dirPath = null;

        @Argument(name = "-t", description = "output terminal (default: png)")
        String terminal = "png";

        Options() {
        }
    }

    /** Output terminals that can be written into the generated {@code set term} line. */
    public enum Terminal {
        PNG("png"),
        PDF("pdf");

        /** Terminal name exactly as it is written into the script. */
        final String term;

        Terminal(String str) {
            this.term = str;
        }

        /** Returns the terminal name as written into the script. */
        @Override
        public String toString() {
            return this.term;
        }

        /**
         * Looks up a terminal by its script name (case-sensitive).
         *
         * @param str the terminal name, e.g. {@code "png"}
         * @return the matching terminal
         * @throws IllegalArgumentException if {@code str} is {@code null} or matches no terminal
         */
        public static Terminal getEnum(String str) {
            for (Terminal terminal : values()) {
                // equals() rather than compareTo(): a null argument must fall through to the
                // IllegalArgumentException below instead of raising a NullPointerException.
                if (terminal.term.equals(str)) {
                    return terminal;
                }
            }
            throw new IllegalArgumentException("Invalid Terminal value: " + str);
        }

        /**
         * Returns all terminals in declaration order. Kept alongside {@link #values()} for
         * callers that use this name.
         *
         * @return a fresh array on every call
         */
        public static Terminal[] valuesCustom() {
            // values() already hands out a fresh copy, so no second copy is needed.
            return values();
        }
    }

    /**
     * Writes a gnuplot script for every one- and two-attribute term of {@code gam}.
     *
     * @param gam       the model whose terms and regressors are plotted
     * @param instances supplies the attribute list and, for terms handled as numeric, the
     *                  observed values to plot over
     * @param str       path of the output directory; created if it does not exist
     * @param terminal  the gnuplot terminal to select in each script
     * @throws IOException if the output directory cannot be created or a script file
     *                     cannot be opened
     */
    public static void generateGnuplotScripts(GAM gam, Instances instances, String str, Terminal terminal) throws IOException {
        List<Attribute> attributes = instances.getAttributes();
        List<int[]> terms = gam.getTerms();
        List<Regressor> regressors = gam.getRegressors();

        File dir = new File(str);
        if (!dir.exists() && !dir.mkdirs() && !dir.isDirectory()) {
            throw new IOException("Cannot create output directory: " + dir.getAbsolutePath());
        }

        // A single scratch instance is reused for all evaluations; only the attributes of
        // the term being plotted are overwritten before each regress() call.
        Instance point = new Instance(new double[attributes.size()]);
        String termName = terminal.toString();

        for (int i = 0; i < terms.size(); i++) {
            int[] term = terms.get(i);
            Regressor regressor = regressors.get(i);
            if (term.length == 1) {
                Attribute attribute = attributes.get(term[0]);
                if (hasSingleValue(attribute)) {
                    // Nothing to plot over a one-state / one-bin attribute.
                    continue;
                }
                writeUnivariateScript(dir, termName, attribute, term[0], regressor, instances, point);
            } else if (term.length == 2) {
                writeBivariateScript(dir, termName, attributes.get(term[0]), attributes.get(term[1]),
                        term, regressor, point);
            }
        }
    }

    /** Returns true for a nominal attribute with one state or a binned attribute with one bin. */
    private static boolean hasSingleValue(Attribute attribute) {
        switch (attribute.getType()) {
            case NOMINAL:
                return ((NominalAttribute) attribute).getStates().length == 1;
            case BINNED:
                return ((BinnedAttribute) attribute).getNumBins() == 1;
            default:
                return false;
        }
    }

    /** Writes the script for a term over one attribute, choosing the plot style by attribute type. */
    private static void writeUnivariateScript(File dir, String termName, Attribute attribute, int attIndex,
            Regressor regressor, Instances instances, Instance point) throws IOException {
        String name = attribute.getName();
        String path = dir.getAbsolutePath() + File.separator + name + ".plt";
        try (PrintWriter out = new PrintWriter(path)) {
            out.printf("set term %s\n", termName);
            out.printf("set output \"%s.%s\"\n", name, termName);
            out.println("set datafile separator \"\t\"");
            switch (attribute.getType()) {
                case NOMINAL: {
                    out.println("set style data histogram");
                    out.println("set style histogram cluster gap 1");
                    out.println("set style fill solid border -1");
                    out.println("set boxwidth 0.9");
                    // Every "set" must precede the plot command: once plot "-" is issued,
                    // gnuplot treats the following lines as inline data up to "e".
                    out.println("set xtic rotate by -90");
                    out.println("plot \"-\" u 2:xtic(1) t \"\"");
                    String[] states = ((NominalAttribute) attribute).getStates();
                    for (int s = 0; s < states.length; s++) {
                        point.setValue(attIndex, s);
                        out.printf("%s\t%f\n", states[s], regressor.regress(point));
                    }
                    out.println("e");
                    break;
                }
                case BINNED: {
                    int numBins = ((BinnedAttribute) attribute).getNumBins();
                    out.printf("set xrange[0:%d]\n", numBins - 1);
                    out.println("plot \"-\" u 1:2 w lp t \"\"");
                    for (int b = 0; b < numBins; b++) {
                        point.setValue(attIndex, b);
                        out.printf("%d\t%f\n", b, regressor.regress(point));
                    }
                    out.println("e");
                    break;
                }
                default: {
                    // Numeric (and any other) attribute: plot over the values seen in the data.
                    List<Double> values = sortedDistinctValues(instances, attIndex);
                    if (!values.isEmpty()) {
                        out.printf("set xrange[%f:%f]\n", values.get(0), values.get(values.size() - 1));
                    }
                    if (regressor instanceof CubicSpline) {
                        // A spline is emitted as a closed-form expression, so no data rows follow.
                        out.println("z(x) = x < 0 ? 0 : x ** 3");
                        out.println("h(x, k) = z(x - k)");
                        out.println(splinePlotCommand((CubicSpline) regressor));
                    } else {
                        out.println("plot \"-\" u 1:2 w lp t \"\"");
                        for (double value : values) {
                            point.setValue(attIndex, value);
                            out.printf("%f\t%f\n", value, regressor.regress(point));
                        }
                        // Terminate the inline data block, as the other branches do.
                        out.println("e");
                    }
                    break;
                }
            }
        }
    }

    /** Collects the distinct values of one attribute across {@code instances}, in ascending order. */
    private static List<Double> sortedDistinctValues(Instances instances, int attIndex) {
        HashSet<Double> distinct = new HashSet<>();
        for (Instance instance : instances) {
            distinct.add(instance.getValue(attIndex));
        }
        List<Double> sorted = new ArrayList<>(distinct);
        Collections.sort(sorted);
        return sorted;
    }

    /**
     * Builds the gnuplot {@code plot} command for a cubic spline: intercept, the first three
     * coefficients applied to x, x^2 and x^3, then one {@code h(x, knot)} term per knot using
     * the coefficients from index 3 onward.
     */
    private static String splinePlotCommand(CubicSpline spline) {
        double[] knots = spline.getKnots();
        double[] coefficients = spline.getCoefficients();
        StringBuilder sb = new StringBuilder();
        sb.append("plot ").append(spline.getIntercept());
        sb.append(" + ").append(coefficients[0]).append(" * x");
        sb.append(" + ").append(coefficients[1]).append(" * (x ** 2)");
        sb.append(" + ").append(coefficients[2]).append(" * (x ** 3)");
        for (int k = 0; k < knots.length; k++) {
            sb.append(" + ").append(coefficients[k + 3]).append(" * ");
            sb.append("h(x, ").append(knots[k]).append(")");
        }
        sb.append(" t \"\"");
        return sb.toString();
    }

    /**
     * Writes the script for a term over two attributes as a matrix image. Rows (y axis)
     * follow the first attribute and columns (x axis) the second.
     */
    private static void writeBivariateScript(File dir, String termName, Attribute first, Attribute second,
            int[] term, Regressor regressor, Instance point) throws IOException {
        String path = dir.getAbsolutePath() + File.separator + first.getName() + "_" + second.getName() + ".plt";
        try (PrintWriter out = new PrintWriter(path)) {
            out.printf("set term %s\n", termName);
            out.printf("set output \"%s_%s.%s\"\n", first.getName(), second.getName(), termName);
            out.println("set datafile separator \"\t\"");

            // Each axis is sized from its OWN attribute's type.
            int rows = gridSize(first);
            int cols = gridSize(second);
            out.printf("set xrange[0:%d]\n", cols - 1);
            out.printf("set yrange[0:%d]\n", rows - 1);

            if (first.getType() == Attribute.Type.NOMINAL) {
                writeNominalTics(out, "ytics", (NominalAttribute) first);
            }
            if (second.getType() == Attribute.Type.NOMINAL) {
                writeNominalTics(out, "xtics", (NominalAttribute) second);
            }

            out.println("plot \"-\" matrix with image t \"\"");
            for (int r = 0; r < rows; r++) {
                point.setValue(term[0], r);
                for (int c = 0; c < cols; c++) {
                    point.setValue(term[1], c);
                    out.print(regressor.regress(point));
                    out.print("\t");
                }
                out.println();
            }
            out.println("e");
        }
    }

    /**
     * Returns the number of grid cells along an attribute's axis: the bin count for a binned
     * attribute, the cardinality for a nominal one, and 0 for any other type.
     */
    private static int gridSize(Attribute attribute) {
        if (attribute.getType() == Attribute.Type.BINNED) {
            return ((BinnedAttribute) attribute).getNumBins();
        }
        if (attribute.getType() == Attribute.Type.NOMINAL) {
            return ((NominalAttribute) attribute).getCardinality();
        }
        return 0;
    }

    /** Emits {@code set <axis> ("label" position, ...)} using the attribute's state names as labels. */
    private static void writeNominalTics(PrintWriter out, String axis, NominalAttribute attribute) {
        String[] states = attribute.getStates();
        out.print("set " + axis + " (");
        for (int s = 0; s < states.length; s++) {
            if (s > 0) {
                out.print(", ");
            }
            // gnuplot expects each entry as a quoted label followed by its position.
            out.printf("\"%s\" %d", states[s], s);
        }
        out.println(")");
    }

    /**
     * Entry point: parses the command line, loads the model and dataset, and writes the scripts.
     * Prints usage and exits with status 1 when argument parsing fails.
     *
     * @param strArr the command-line arguments
     * @throws Exception if the terminal name is invalid, or loading inputs or writing
     *                   scripts fails
     */
    public static void main(String[] strArr) throws Exception {
        Options options = new Options();
        CmdLineParser parser = new CmdLineParser(Visualizer.class, options);
        try {
            parser.parse(strArr);
        } catch (IllegalArgumentException e) {
            parser.printUsage();
            System.exit(1);
        }
        // Validate the terminal name first so a typo fails before any file is read.
        Terminal terminal = Terminal.getEnum(options.terminal);
        GAM gam = (GAM) PredictorReader.read(options.inputModelPath, GAM.class);
        Instances instances = InstancesReader.read(options.attPath, options.datasetPath);
        generateGnuplotScripts(gam, instances, options.dirPath, terminal);
    }
}
