package mltk.predictor.tree.ensemble.brt;

import java.io.BufferedReader;
import java.io.PrintWriter;
import java.util.Iterator;
import mltk.core.Instance;
import mltk.predictor.Predictor;
import mltk.predictor.ProbabilisticClassifier;
import mltk.predictor.Regressor;
import mltk.predictor.tree.RegressionTree;
import mltk.util.StatUtils;
import mltk.util.VectorUtils;

/* loaded from: input_file:mltk/predictor/tree/ensemble/brt/BRT.class */
/**
 * A predictor backed by an array of {@link BoostedRegressionTrees} lists.
 *
 * <p>For classification, each list produces one score per class. The scores are turned into
 * probabilities with a numerically stable softmax (see {@link #predictProbabilities(Instance)}).
 * For regression, only the first list is used (see {@link #regress(Instance)}).
 */
public class BRT implements ProbabilisticClassifier, Regressor {
    /** One boosted tree list per output (class); index {@code k} scores output {@code k}. */
    protected BoostedRegressionTrees[] trees;

    /** Creates an empty model; {@link #trees} stays {@code null} until {@link #read} or assignment. */
    public BRT() {
    }

    /**
     * Creates a model with {@code k} empty tree lists.
     *
     * @param k number of tree lists (outputs) to allocate
     */
    public BRT(int k) {
        this.trees = new BoostedRegressionTrees[k];
        for (int idx = 0; idx < this.trees.length; idx++) {
            this.trees[idx] = new BoostedRegressionTrees();
        }
    }

    /**
     * Returns the tree list for output {@code i}.
     *
     * @param i index of the list
     * @return the (live, not copied) tree list at index {@code i}
     */
    public BoostedRegressionTrees getRegressionTreeList(int i) {
        return this.trees[i];
    }

    /**
     * Classifies an instance as the index with the highest predicted probability.
     *
     * @param instance the instance to classify
     * @return the index of the maximum entry of {@link #predictProbabilities(Instance)}
     */
    @Override // mltk.predictor.Classifier
    public int classify(Instance instance) {
        return StatUtils.indexOfMax(predictProbabilities(instance));
    }

    /**
     * Reads a model in the format produced by {@link #write(PrintWriter)}, excluding the
     * leading {@code [Predictor: ...]} line. That line is not consumed here; presumably the
     * caller reads it to choose the class. TODO confirm.
     *
     * <p>Expected layout: a {@code "K: <n>"} line, then for each of the {@code n} lists a
     * {@code "Length: <m>"} line followed by {@code m} serialized trees, each followed by one
     * separator line.
     *
     * <p>The field {@link #trees} is replaced only after the whole model has been read, so a
     * failed read leaves the previous state intact.
     *
     * @param bufferedReader source of the serialized model
     * @throws IllegalArgumentException if a header line is missing (end of input) or malformed
     * @throws Exception propagated from {@code RegressionTree.read} or the reader
     */
    @Override // mltk.core.Writable
    public void read(BufferedReader bufferedReader) throws Exception {
        int numLists = parseHeaderValue(bufferedReader.readLine());
        BoostedRegressionTrees[] lists = new BoostedRegressionTrees[numLists];
        for (int k = 0; k < numLists; k++) {
            lists[k] = new BoostedRegressionTrees();
            int numTrees = parseHeaderValue(bufferedReader.readLine());
            // Dedicated per-list counter. The previous version reused the list index here and
            // tested a constant condition, so it ran off the end of the array.
            for (int t = 0; t < numTrees; t++) {
                RegressionTree regressionTree = new RegressionTree();
                regressionTree.read(bufferedReader);
                lists[k].add(regressionTree);
                // Consume the separator line that follows each serialized tree.
                // NOTE(review): write() emits one blank line per list, not per tree; confirm
                // that RegressionTree.write/read framing keeps the two sides in step.
                bufferedReader.readLine();
            }
        }
        this.trees = lists;
    }

    /**
     * Parses the integer value of a {@code "<key>: <value>"} header line.
     *
     * @param line the raw line, possibly {@code null} at end of input
     * @return the parsed integer value
     * @throws IllegalArgumentException if {@code line} is {@code null} or has no {@code ": "}
     */
    private static int parseHeaderValue(String line) {
        if (line == null) {
            throw new IllegalArgumentException("Unexpected end of input while reading BRT header");
        }
        String[] parts = line.split(": ");
        if (parts.length < 2) {
            throw new IllegalArgumentException("Malformed BRT header line: " + line);
        }
        return Integer.parseInt(parts[1].trim());
    }

    /**
     * Writes the model: a {@code [Predictor: <class>]} line, a {@code "K: <n>"} line, then for
     * each list a {@code "Length: <m>"} line, its {@code m} trees, and one blank line.
     *
     * @param printWriter destination of the serialized model
     * @throws Exception propagated from {@code RegressionTree.write}
     */
    @Override // mltk.core.Writable
    public void write(PrintWriter printWriter) throws Exception {
        printWriter.printf("[Predictor: %s]\n", getClass().getCanonicalName());
        printWriter.println("K: " + this.trees.length);
        for (BoostedRegressionTrees boostedRegressionTrees : this.trees) {
            printWriter.println("Length: " + boostedRegressionTrees.size());
            Iterator<RegressionTree> it = boostedRegressionTrees.iterator();
            while (it.hasNext()) {
                it.next().write(printWriter);
            }
            printWriter.println();
        }
    }

    /**
     * Regression prediction: the output of the first tree list only.
     *
     * @param instance the instance to score
     * @return {@code trees[0].regress(instance)}
     */
    @Override // mltk.predictor.Regressor
    public double regress(Instance instance) {
        return this.trees[0].regress(instance);
    }

    /**
     * Computes class probabilities as the softmax of the per-list scores.
     *
     * <p>The maximum score is subtracted before exponentiation, so large scores do not
     * overflow {@link Math#exp(double)}. This does not change the normalized result.
     *
     * @param instance the instance to score
     * @return an array of length {@code trees.length} whose entries sum to 1
     */
    @Override // mltk.predictor.ProbabilisticClassifier
    public double[] predictProbabilities(Instance instance) {
        int numOutputs = this.trees.length;
        double[] scores = new double[numOutputs];
        for (int k = 0; k < numOutputs; k++) {
            scores[k] = this.trees[k].regress(instance);
        }
        double maxScore = StatUtils.max(scores);
        double[] probs = new double[numOutputs];
        double total = 0.0d;
        for (int k = 0; k < numOutputs; k++) {
            probs[k] = Math.exp(scores[k] - maxScore);
            total += probs[k];
        }
        VectorUtils.divide(probs, total);
        return probs;
    }

    /**
     * Returns a copy of this model with the same number of lists and a copy of every tree.
     *
     * @return a new {@code BRT} whose trees are the result of {@code copy2()} on each tree
     */
    @Override // mltk.core.Copyable
    /* renamed from: copy */
    public Predictor copy2() {
        BRT brt = new BRT(this.trees.length);
        for (int k = 0; k < this.trees.length; k++) {
            Iterator<RegressionTree> it = this.trees[k].iterator();
            while (it.hasNext()) {
                brt.trees[k].add(it.next().copy2());
            }
        }
        return brt;
    }
}
