package mltk.predictor.gam;

import java.util.ArrayList;
import java.util.List;
import mltk.cmdline.Argument;
import mltk.cmdline.CmdLineParser;
import mltk.core.Attribute;
import mltk.core.BinnedAttribute;
import mltk.core.Instances;
import mltk.core.NominalAttribute;
import mltk.core.io.InstancesReader;
import mltk.predictor.BaggedEnsemble;
import mltk.predictor.BaggedEnsembleLearner;
import mltk.predictor.Bagging;
import mltk.predictor.BoostedEnsemble;
import mltk.predictor.HoldoutValidatedLearner;
import mltk.predictor.Learner;
import mltk.predictor.Regressor;
import mltk.predictor.evaluation.Metric;
import mltk.predictor.evaluation.MetricFactory;
import mltk.predictor.function.CompressionUtils;
import mltk.predictor.function.Function1D;
import mltk.predictor.function.LineCutter;
import mltk.predictor.io.PredictorWriter;
import mltk.util.OptimUtils;
import mltk.util.Random;
import weka.core.json.JSONInstances;

/**
 * Learns a generalized additive model (GAM) by cyclic boosting of one-dimensional shape
 * functions.
 *
 * <p>Iteration {@code t} fits a bagged ensemble of {@link LineCutter}s on the attribute at
 * position {@code t % numAttributes} against the current residuals, scales it by the learning
 * rate, and appends it to that attribute's {@link BoostedEnsemble}. After training, every
 * attribute's ensemble is compressed into a single {@link Function1D} (further converted with
 * {@code CompressionUtils.convert} for binned and nominal attributes) and added to the
 * resulting {@link GAM}.
 *
 * <p>When a validation set is supplied, the evaluation metric is computed on it after every
 * iteration and the model is truncated back to the best-scoring iteration.
 *
 * <p>The build methods temporarily overwrite the targets of the training instances with
 * residuals; the original targets are restored before the method returns, also when it exits
 * with an exception.
 */
public class GAMLearner extends HoldoutValidatedLearner {

    /** Whether the evaluation metric is printed after every boosting iteration. */
    private boolean verbose = false;

    /** Number of bags passed to {@code Bagging.createBags} for every boosting stage. */
    private int baggingIters = 100;

    /**
     * Maximum number of boosting iterations. A negative value makes {@link #build(Instances)}
     * use 20 iterations per attribute.
     */
    private int maxNumIters = -1;

    /** Value passed to {@code LineCutter.setNumIntervals} for every boosting stage. */
    private int maxNumLeaves = 3;

    /** Factor every fitted stage is multiplied by before it is added to the model. */
    private double learningRate = 1.0d;

    /** Selects between classification and regression training in {@link #build(Instances)}. */
    private Learner.Task task = Learner.Task.REGRESSION;

    /** Command-line options for {@link #main(String[])}. */
    static class Options {

        @Argument(name = "-r", description = "attribute file path")
        String attPath = null;

        @Argument(name = "-t", description = "train set path", required = true)
        String trainPath = null;

        @Argument(name = "-v", description = "valid set path")
        String validPath = null;

        @Argument(name = "-o", description = "output model path")
        String outputModelPath = null;

        @Argument(name = "-g", description = "task between classification (c) and regression (r) (default: r)")
        String task = "r";

        @Argument(name = "-e", description = "evaluation metric (default: default metric of task)")
        String metric = null;

        @Argument(name = "-b", description = "base learner (default: tr:3:100)")
        String baseLearner = "tr:3:100";

        @Argument(name = "-m", description = "maximum number of iterations", required = true)
        int maxNumIters = -1;

        @Argument(name = "-s", description = "seed of the random number generator (default: 0)")
        long seed = 0;

        @Argument(name = "-l", description = "learning rate (default: 0.01)")
        double learningRate = 0.01d;
    }

    /** Creates a regression learner that evaluates with the regression task's default metric. */
    public GAMLearner() {
        this.metric = this.task.getDefaultMetric();
    }

    /** Returns whether per-iteration progress is printed. */
    public boolean isVerbose() {
        return this.verbose;
    }

    /** Sets whether per-iteration progress is printed to standard output. */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Returns the number of bags used for each boosting stage. */
    public int getBaggingIters() {
        return this.baggingIters;
    }

    /** Sets the number of bags used for each boosting stage. */
    public void setBaggingIters(int baggingIters) {
        this.baggingIters = baggingIters;
    }

    /** Returns the configured maximum number of iterations (negative means "20 per attribute"). */
    public int getMaxNumIters() {
        return this.maxNumIters;
    }

    /** Sets the maximum number of iterations; a negative value means "20 per attribute". */
    public void setMaxNumIters(int maxNumIters) {
        this.maxNumIters = maxNumIters;
    }

    /** Returns the value used as the number of intervals of each line cutter. */
    public int getMaxNumLeaves() {
        return this.maxNumLeaves;
    }

    /** Sets the value used as the number of intervals of each line cutter. */
    public void setMaxNumLeaves(int maxNumLeaves) {
        this.maxNumLeaves = maxNumLeaves;
    }

    /** Returns the factor applied to every fitted stage. */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** Sets the factor applied to every fitted stage. */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** Returns the learning task. */
    public Learner.Task getTask() {
        return this.task;
    }

    /**
     * Sets the learning task.
     *
     * <p>This does not change the evaluation metric: a metric chosen for the previous task
     * (including the regression default set by the constructor) stays in effect, so call
     * {@code setMetric} as well when switching tasks.
     */
    public void setTask(Learner.Task task) {
        this.task = task;
    }

    /**
     * Configures the base learner from an option string.
     *
     * <p>{@code tr:<maxNumLeaves>:<baggingIters>} sets {@link #setMaxNumLeaves(int)} and
     * {@link #setBaggingIters(int)}; {@code cs} and unknown learner names are accepted and
     * ignored.
     *
     * @param option the option string, e.g. {@code "tr:3:100"}
     * @throws IllegalArgumentException if a {@code tr} option lacks either number or a number
     *     is malformed (nothing is changed in that case)
     */
    public void setBaseLearner(String option) {
        String[] parts = option.split(":");
        switch (parts[0]) {
            case "tr":
                if (parts.length < 3) {
                    throw new IllegalArgumentException("Invalid base learner option \"" + option
                            + "\": expected tr:<maxNumLeaves>:<baggingIters>");
                }
                // Parse both values before applying either, so a bad option changes nothing.
                int leaves = Integer.parseInt(parts[1]);
                int bags = Integer.parseInt(parts[2]);
                setMaxNumLeaves(leaves);
                setBaggingIters(bags);
                break;
            case "cs":
                // Accepted for compatibility; currently has no effect.
                break;
            default:
                break;
        }
    }

    /**
     * Builds a classification GAM, stopping early on a validation set.
     *
     * @param train training set; its targets are temporarily modified and then restored
     * @param valid validation set used to choose the best iteration, or {@code null} to keep
     *     all iterations
     * @param maxNumIters number of boosting iterations
     * @param maxNumLeaves number of intervals of each line cutter
     * @return the fitted model
     */
    public GAM buildClassifier(Instances train, Instances valid, int maxNumIters, int maxNumLeaves) {
        return buildGAM(train, valid, maxNumIters, maxNumLeaves, true);
    }

    /**
     * Builds a classification GAM using all {@code maxNumIters} iterations.
     *
     * @param train training set; its targets are temporarily modified and then restored
     * @param maxNumIters number of boosting iterations
     * @param maxNumLeaves number of intervals of each line cutter
     * @return the fitted model
     */
    public GAM buildClassifier(Instances train, int maxNumIters, int maxNumLeaves) {
        return buildGAM(train, null, maxNumIters, maxNumLeaves, true);
    }

    /**
     * Builds a regression GAM, stopping early on a validation set.
     *
     * @param train training set; its targets are temporarily modified and then restored
     * @param valid validation set used to choose the best iteration, or {@code null} to keep
     *     all iterations
     * @param maxNumIters number of boosting iterations
     * @param maxNumLeaves number of intervals of each line cutter
     * @return the fitted model
     */
    public GAM buildRegressor(Instances train, Instances valid, int maxNumIters, int maxNumLeaves) {
        return buildGAM(train, valid, maxNumIters, maxNumLeaves, false);
    }

    /**
     * Builds a regression GAM using all {@code maxNumIters} iterations.
     *
     * @param train training set; its targets are temporarily modified and then restored
     * @param maxNumIters number of boosting iterations
     * @param maxNumLeaves number of intervals of each line cutter
     * @return the fitted model
     */
    public GAM buildRegressor(Instances train, int maxNumIters, int maxNumLeaves) {
        return buildGAM(train, null, maxNumIters, maxNumLeaves, false);
    }

    /**
     * Builds a GAM for the configured task, using the configured validation set (if any) for
     * early stopping. When the configured maximum number of iterations is negative, 20
     * iterations per attribute are used for this call; the configured value is not modified.
     *
     * @param instances training set; its targets are temporarily modified and then restored
     * @return the fitted model, or {@code null} for a task this learner does not handle
     */
    @Override
    public GAM build(Instances instances) {
        int numIters = this.maxNumIters >= 0
                ? this.maxNumIters
                : instances.getAttributes().size() * 20;
        if (this.metric == null) {
            this.metric = this.task.getDefaultMetric();
        }
        switch (this.task) {
            case CLASSIFICATION:
                return buildClassifier(instances, this.validSet, numIters, this.maxNumLeaves);
            case REGRESSION:
                return buildRegressor(instances, this.validSet, numIters, this.maxNumLeaves);
            default:
                return null;
        }
    }

    /**
     * Shared boosting loop for classification and regression.
     *
     * @param train training set; targets are overwritten with residuals during training and
     *     restored in a {@code finally} block
     * @param valid optional validation set; when non-null the model is truncated to the best
     *     iteration according to {@link #metric}
     * @param maxNumIters number of iterations (negative is treated as zero)
     * @param numIntervals value passed to {@code LineCutter.setNumIntervals}
     * @param classification whether to use the classification line cutter and pseudo-residuals
     *     (otherwise plain residuals are used)
     * @return the fitted model
     */
    private GAM buildGAM(Instances train, Instances valid, int maxNumIters, int numIntervals,
            boolean classification) {
        List<Attribute> attributes = train.getAttributes();
        int numFeatures = attributes.size();
        if (numFeatures == 0) {
            // Nothing to fit; also avoids a division by zero in the cyclic schedule below.
            return new GAM();
        }
        int numIters = Math.max(maxNumIters, 0);

        List<BoostedEnsemble> ensembles = new ArrayList<>(numFeatures);
        for (int j = 0; j < numFeatures; j++) {
            ensembles.add(new BoostedEnsemble());
        }

        double[] targets = getTargets(train);
        Instances[] bags = Bagging.createBags(train, this.baggingIters);
        LineCutter lineCutter = classification ? new LineCutter(true) : new LineCutter();
        lineCutter.setNumIntervals(numIntervals);
        BaggedEnsembleLearner baggedLearner = new BaggedEnsembleLearner(bags.length, lineCutter);

        int numTrain = targets.length;
        double[] trainPreds = new double[numTrain];
        double[] residuals = new double[numTrain];
        if (classification) {
            OptimUtils.computePseudoResidual(trainPreds, targets, residuals);
        } else {
            System.arraycopy(targets, 0, residuals, 0, numTrain);
        }
        double[] validPreds = valid == null ? null : new double[valid.size()];
        double[] validEvals = valid == null ? null : new double[numIters];

        try {
            for (int iter = 0; iter < numIters; iter++) {
                int feature = iter % numFeatures;
                // The bags are presumably views over the same Instance objects as train
                // (TODO confirm), which is how the residuals become this stage's targets.
                setTargets(train, residuals);
                // NOTE(review): this is the attribute's position in getAttributes(), while
                // toGAM uses Attribute.getIndex(); these are assumed to coincide.
                lineCutter.setAttributeIndex(feature);
                BaggedEnsemble stage = fitStage(baggedLearner, bags);
                ensembles.get(feature).add(stage);

                for (int k = 0; k < numTrain; k++) {
                    double delta = stage.regress(train.get(k));
                    trainPreds[k] += delta;
                    residuals[k] = classification
                            ? OptimUtils.getPseudoResidual(trainPreds[k], targets[k])
                            : residuals[k] - delta;
                }

                if (valid != null) {
                    for (int k = 0; k < validPreds.length; k++) {
                        validPreds[k] += stage.regress(valid.get(k));
                    }
                    double eval = this.metric.eval(validPreds, valid);
                    validEvals[iter] = eval;
                    if (this.verbose) {
                        report(iter, feature, eval);
                    }
                } else if (this.verbose) {
                    // Training-set metric is only needed for progress output.
                    report(iter, feature, this.metric.eval(trainPreds, targets));
                }
            }
        } finally {
            setTargets(train, targets);
        }

        if (valid != null) {
            truncate(ensembles, findBestIteration(validEvals));
        }
        return toGAM(attributes, ensembles);
    }

    /** Fits one bagged stage on the bags and scales it by the learning rate. */
    private BaggedEnsemble fitStage(BaggedEnsembleLearner learner, Instances[] bags) {
        BaggedEnsemble stage = learner.build(bags);
        if (this.learningRate != 1.0d) {
            for (int k = 0; k < stage.size(); k++) {
                ((Function1D) stage.get(k)).multiply(this.learningRate);
            }
        }
        return stage;
    }

    /**
     * Returns the earliest iteration whose evaluation is strictly better than every earlier one
     * and than the metric's worst value, or -1 if there is none.
     */
    private int findBestIteration(double[] evals) {
        double best = this.metric.worstValue();
        int bestIter = -1;
        for (int k = 0; k < evals.length; k++) {
            if (this.metric.isFirstBetter(evals[k], best)) {
                best = evals[k];
                bestIter = k;
            }
        }
        return bestIter;
    }

    /**
     * Removes every stage added after iteration {@code bestIter} (all stages when it is -1).
     * Iteration {@code k} trained attribute {@code k % n}, so attribute {@code j} keeps the
     * stages of iterations {@code j, j + n, ...} that are {@code <= bestIter}. Counting them
     * directly stays correct when the final round of iterations was only partially completed.
     */
    private static void truncate(List<BoostedEnsemble> ensembles, int bestIter) {
        int numFeatures = ensembles.size();
        for (int j = 0; j < numFeatures; j++) {
            int keep = bestIter < j ? 0 : (bestIter - j) / numFeatures + 1;
            BoostedEnsemble ensemble = ensembles.get(j);
            while (ensemble.size() > keep) {
                ensemble.removeLast();
            }
        }
    }

    /** Compresses each attribute's ensemble into one shape function and assembles the GAM. */
    private static GAM toGAM(List<Attribute> attributes, List<BoostedEnsemble> ensembles) {
        GAM gam = new GAM();
        for (int j = 0; j < ensembles.size(); j++) {
            Attribute attribute = attributes.get(j);
            int index = attribute.getIndex();
            Function1D function = CompressionUtils.compress(index, ensembles.get(j));
            Regressor regressor = function;
            if (attribute.getType() == Attribute.Type.BINNED) {
                regressor = CompressionUtils.convert(((BinnedAttribute) attribute).getNumBins(), function);
            } else if (attribute.getType() == Attribute.Type.NOMINAL) {
                regressor = CompressionUtils.convert(((NominalAttribute) attribute).getCardinality(), function);
            }
            gam.add(new int[] {index}, regressor);
        }
        return gam;
    }

    /** Returns a copy of the targets of all instances, in order. */
    private static double[] getTargets(Instances instances) {
        double[] targets = new double[instances.size()];
        for (int k = 0; k < targets.length; k++) {
            targets[k] = instances.get(k).getTarget();
        }
        return targets;
    }

    /** Overwrites the target of instance {@code k} with {@code targets[k]}. */
    private static void setTargets(Instances instances, double[] targets) {
        for (int k = 0; k < targets.length; k++) {
            instances.get(k).setTarget(targets[k]);
        }
    }

    /** Prints the progress line for one iteration. */
    private static void report(int iter, int feature, double eval) {
        System.out.println("Iteration " + iter + " Feature " + feature + ": " + eval);
    }

    /**
     * Trains a GAM from the command line. Prints usage and exits with status 1 if the
     * arguments, task, metric, or base-learner option are invalid.
     */
    public static void main(String[] args) throws Exception {
        Options opts = new Options();
        CmdLineParser parser = new CmdLineParser(GAMLearner.class, opts);
        GAMLearner learner = new GAMLearner();
        try {
            parser.parse(args);
            Learner.Task task = Learner.Task.getEnum(opts.task);
            Metric metric = opts.metric == null
                    ? task.getDefaultMetric()
                    : MetricFactory.getMetric(opts.metric);
            learner.setBaseLearner(opts.baseLearner);
            learner.setTask(task);
            learner.setMetric(metric);
        } catch (IllegalArgumentException e) {
            parser.printUsage();
            System.exit(1);
        }
        learner.setMaxNumIters(opts.maxNumIters);
        learner.setLearningRate(opts.learningRate);
        learner.setVerbose(true);

        Random.getInstance().setSeed(opts.seed);
        Instances trainSet = InstancesReader.read(opts.attPath, opts.trainPath);
        if (opts.validPath != null) {
            learner.setValidSet(InstancesReader.read(opts.attPath, opts.validPath));
        }

        long start = System.currentTimeMillis();
        GAM gam = learner.build(trainSet);
        System.out.println("Time: " + ((System.currentTimeMillis() - start) / 1000.0d));

        if (opts.outputModelPath != null) {
            PredictorWriter.write(gam, opts.outputModelPath);
        }
    }
}
