package mltk.feature.selection;

import java.util.ArrayList;
import java.util.List;
import mltk.core.Attribute;
import mltk.core.Instances;
import mltk.predictor.BaggedEnsembleLearner;
import mltk.predictor.evaluation.Evaluator;
import mltk.predictor.tree.ensemble.ag.AdditiveGrovesLearner;
import mltk.util.StatUtils;
import mltk.util.tuple.DoublePair;
import mltk.util.tuple.Pair;

/* loaded from: input_file:mltk/feature/selection/BackwardElimination.class */
/**
 * Greedy backward elimination of attributes, driven by validation RMSE.
 *
 * <p>Each round first estimates a baseline for the current attribute set: the mean and standard
 * deviation ({@link StatUtils#mean}, {@link StatUtils#std}) of the validation RMSE over several
 * independently trained models. It then walks the remaining attributes in order. For each one it trains a
 * single model without that attribute and without any attribute already eliminated in this round.
 * The attribute is eliminated when that model's RMSE lies inside
 * {@code [mean - 3 * std, mean + 3 * std]}. Rounds repeat until one round eliminates nothing or no
 * attributes remain.
 *
 * <p>The returned baseline is the one computed at the start of the last round. If that round
 * eliminated every remaining attribute, the baseline describes the set evaluated at the start of that
 * round, not the (empty) returned list. If the training set has no attributes to begin with, the
 * baseline is {@code null}.
 *
 * <p>NOTE(review): the baseline spread is only informative if repeated training on the same data
 * yields different models; confirm this holds for the learner in use.
 *
 * <p>All progress is printed to {@code System.out}.
 */
public class BackwardElimination {

    /** Half-width of the tolerance band around the baseline mean, in baseline standard deviations. */
    private static final double TOLERANCE_STDS = 3.0d;

    /** Trains one model on the given training set and returns its RMSE on a fixed validation set. */
    private interface RmseEvaluator {
        double trainAndEvaluate(Instances trainSet);
    }

    /**
     * Runs backward elimination with a bagged ensemble learner.
     *
     * <p>Every evaluation builds a fresh model with {@code learner.build(trainSet)} and scores it with
     * {@link Evaluator#evalRMSE} on {@code validSet}.
     *
     * @param trainSet training set; its attribute list is replaced while the search runs and is
     *     restored before this method returns or throws
     * @param validSet validation set on which RMSE is measured
     * @param learner learner used to build every model
     * @param numEvalRuns number of models trained per round to estimate the baseline; must be >= 1
     * @return the surviving attributes paired with the last baseline (mean, std) of validation RMSE
     * @throws IllegalArgumentException if {@code numEvalRuns < 1}
     */
    public static Pair<List<Attribute>, DoublePair> select(Instances trainSet, final Instances validSet,
            final BaggedEnsembleLearner learner, int numEvalRuns) {
        RmseEvaluator evaluator = new RmseEvaluator() {
            @Override
            public double trainAndEvaluate(Instances subsetTrainSet) {
                return Evaluator.evalRMSE(learner.build(subsetTrainSet), validSet);
            }
        };
        return eliminate(trainSet, evaluator, numEvalRuns, false);
    }

    /**
     * Runs backward elimination with an Additive Groves learner.
     *
     * <p>Every evaluation trains a model with
     * {@code learner.runLayeredTraining(trainSet, trainingArg1, trainingArg2, trainingArg3)} and scores
     * it with {@link Evaluator#evalRMSE} on {@code validSet}. This variant also prints each baseline
     * model's RMSE and a line before each candidate model is trained.
     *
     * @param trainSet training set; its attribute list is replaced while the search runs and is
     *     restored before this method returns or throws
     * @param validSet validation set on which RMSE is measured
     * @param learner learner used to train every model
     * @param trainingArg1 forwarded unchanged as the 2nd argument of {@code runLayeredTraining}
     * @param trainingArg2 forwarded unchanged as the 3rd argument of {@code runLayeredTraining}
     * @param trainingArg3 forwarded unchanged as the 4th argument of {@code runLayeredTraining}
     * @param numEvalRuns number of models trained per round to estimate the baseline; must be >= 1
     * @return the surviving attributes paired with the last baseline (mean, std) of validation RMSE
     * @throws IllegalArgumentException if {@code numEvalRuns < 1}
     */
    public static Pair<List<Attribute>, DoublePair> select(Instances trainSet, final Instances validSet,
            final AdditiveGrovesLearner learner, final int trainingArg1, final int trainingArg2,
            final double trainingArg3, int numEvalRuns) {
        RmseEvaluator evaluator = new RmseEvaluator() {
            @Override
            public double trainAndEvaluate(Instances subsetTrainSet) {
                return Evaluator.evalRMSE(
                        learner.runLayeredTraining(subsetTrainSet, trainingArg1, trainingArg2, trainingArg3),
                        validSet);
            }
        };
        return eliminate(trainSet, evaluator, numEvalRuns, true);
    }

    /**
     * Shared elimination loop used by both {@code select} overloads.
     *
     * @param trainSet training set whose attribute list is varied; restored in a {@code finally}
     * @param evaluator trains a model on the current attribute subset and returns its validation RMSE
     * @param numEvalRuns number of models trained per round to estimate the baseline; must be >= 1
     * @param verbose also print per-model baseline RMSEs and a line before each candidate is trained
     * @return the surviving attributes paired with the last baseline (may be {@code null})
     * @throws IllegalArgumentException if {@code numEvalRuns < 1}
     */
    private static Pair<List<Attribute>, DoublePair> eliminate(Instances trainSet, RmseEvaluator evaluator,
            int numEvalRuns, boolean verbose) {
        if (numEvalRuns < 1) {
            throw new IllegalArgumentException("numEvalRuns must be at least 1, got " + numEvalRuns);
        }
        List<Attribute> originalAttributes = trainSet.getAttributes();
        List<Attribute> selected = new ArrayList<>(originalAttributes);
        DoublePair baseline = null;
        try {
            while (!selected.isEmpty()) {
                trainSet.setAttributes(selected);
                baseline = evaluateModel(trainSet, evaluator, numEvalRuns, verbose);
                System.out.println("Mean: " + baseline.v1 + " Std: " + baseline.v2);
                double lower = baseline.v1 - (baseline.v2 * TOLERANCE_STDS);
                double upper = baseline.v1 + (baseline.v2 * TOLERANCE_STDS);
                boolean eliminatedAny = false;
                int i = 0;
                while (i < selected.size()) {
                    // Work on a copy so the candidate subset can be installed without touching `selected`.
                    List<Attribute> candidate = new ArrayList<>(selected);
                    Attribute attribute = candidate.remove(i);
                    if (verbose) {
                        System.out.println("Testing: " + attribute.getName());
                    }
                    trainSet.setAttributes(candidate);
                    double rmse = evaluator.trainAndEvaluate(trainSet);
                    System.out.println("Testing: " + attribute.getName() + " RMSE: " + rmse);
                    // Phrased as "inside the band" so a NaN RMSE or NaN baseline keeps the attribute.
                    // Removals that lower RMSE below the band are also kept.
                    // NOTE(review): confirm that keeping those is intended.
                    if (lower <= rmse && rmse <= upper) {
                        selected.remove(i); // the next attribute shifts into slot i, so do not advance
                        eliminatedAny = true;
                        System.out.println("Eliminate: " + attribute.getName());
                    } else {
                        i++;
                    }
                }
                if (!eliminatedAny) {
                    break;
                }
            }
        } finally {
            // Always hand the caller's training set back with its original attributes.
            trainSet.setAttributes(originalAttributes);
        }
        return new Pair<>(selected, baseline);
    }

    /**
     * Trains {@code numEvalRuns} models on {@code trainSet}'s current attributes.
     *
     * @return (mean, std) of their validation RMSEs, via {@link StatUtils#mean} and {@link StatUtils#std}
     */
    private static DoublePair evaluateModel(Instances trainSet, RmseEvaluator evaluator, int numEvalRuns,
            boolean verbose) {
        double[] rmses = new double[numEvalRuns];
        for (int run = 0; run < numEvalRuns; run++) {
            rmses[run] = evaluator.trainAndEvaluate(trainSet);
            if (verbose) {
                System.out.println("\tEvaluating model " + (run + 1) + " / " + numEvalRuns + "\t" + rmses[run]);
            }
        }
        return new DoublePair(StatUtils.mean(rmses), StatUtils.std(rmses));
    }
}
