package mltk.predictor.function;

import java.io.BufferedReader;
import java.io.PrintWriter;
import java.util.Arrays;
import mltk.core.Instance;
import mltk.predictor.Predictor;
import mltk.predictor.Regressor;
import mltk.util.ArrayUtils;
import mltk.util.VectorUtils;

/* loaded from: input_file:mltk/predictor/function/Function1D.class */
/**
 * Piecewise-constant function of a single attribute.
 *
 * <p>The function is stored as two parallel arrays. {@code splits} holds the segment
 * boundaries in ascending order (it is searched with {@link Arrays#binarySearch(double[], double)}),
 * and {@code predictions[i]} is the value returned for every input {@code x} whose segment
 * index is {@code i}, i.e. {@code splits[i - 1] < x <= splits[i]}. The factory methods in this
 * class always terminate {@code splits} with {@link Double#POSITIVE_INFINITY} so that every
 * non-NaN input falls into some segment.
 *
 * <p>Instances are mutable: the arithmetic methods modify this function and return {@code this}
 * to allow chaining, and the array getters expose the internal arrays without copying.
 */
public class Function1D implements Regressor, UnivariateFunction {
    /** Index of the attribute whose value is fed to this function. */
    protected int attIndex;
    /** Ascending segment boundaries; segment {@code i} ends (inclusively) at {@code splits[i]}. */
    protected double[] splits;
    /** Value of the function on each segment; indexed by segment index. */
    protected double[] predictions;

    /**
     * Creates a function that returns the same value for every input.
     *
     * @param attIndex index of the attribute the function is defined on
     * @param value the constant value of the function
     * @return a new single-segment function whose only split is positive infinity
     */
    public static Function1D getConstantFunction(int attIndex, double value) {
        return new Function1D(attIndex, new double[]{Double.POSITIVE_INFINITY}, new double[]{value});
    }

    /**
     * Resets this function to a single segment ending at positive infinity with value zero.
     * The attribute index is left unchanged.
     */
    public void setZero() {
        this.splits = new double[]{Double.POSITIVE_INFINITY};
        this.predictions = new double[]{0.0d};
    }

    /**
     * Checks the whole predictions array against zero using {@code ArrayUtils.isConstant}.
     *
     * @return the result of {@code ArrayUtils.isConstant} over all predictions with value 0.0
     *         (presumably {@code true} when every prediction is zero)
     */
    public boolean isZero() {
        return ArrayUtils.isConstant(this.predictions, 0, this.predictions.length, 0.0d);
    }

    /**
     * Checks predictions 1..n-1 against the first prediction using {@code ArrayUtils.isConstant}.
     * The predictions array must be non-empty, because the first element is read unconditionally.
     *
     * @return the result of {@code ArrayUtils.isConstant} (presumably {@code true} when all
     *         predictions are equal)
     */
    public boolean isConstant() {
        return ArrayUtils.isConstant(this.predictions, 1, this.predictions.length, this.predictions[0]);
    }

    /**
     * Applies {@code VectorUtils.multiply} to the predictions with the given factor
     * (presumably scaling them in place).
     *
     * @param factor the multiplier
     * @return this function
     */
    public Function1D multiply(double factor) {
        VectorUtils.multiply(this.predictions, factor);
        return this;
    }

    /**
     * Applies {@code VectorUtils.divide} to the predictions with the given divisor
     * (presumably dividing them in place).
     *
     * @param divisor the divisor
     * @return this function
     */
    public Function1D divide(double divisor) {
        VectorUtils.divide(this.predictions, divisor);
        return this;
    }

    /**
     * Applies {@code VectorUtils.add} to the predictions with the given constant
     * (presumably shifting them in place).
     *
     * @param constant the value to add
     * @return this function
     */
    public Function1D add(double constant) {
        VectorUtils.add(this.predictions, constant);
        return this;
    }

    /**
     * Applies {@code VectorUtils.subtract} to the predictions with the given constant
     * (presumably shifting them in place).
     *
     * @param constant the value to subtract
     * @return this function
     */
    public Function1D subtract(double constant) {
        VectorUtils.subtract(this.predictions, constant);
        return this;
    }

    /**
     * Creates an uninitialized function. {@code splits} and {@code predictions} are
     * {@code null} until they are set, e.g. through {@link #read(BufferedReader)},
     * {@link #setZero()} or the setters.
     */
    public Function1D() {
    }

    /**
     * Creates a function from the given arrays. The arrays are stored as-is, not copied.
     *
     * @param attIndex index of the attribute the function is defined on
     * @param splits ascending segment boundaries
     * @param predictions value of the function on each segment
     */
    public Function1D(int attIndex, double[] splits, double[] predictions) {
        this.attIndex = attIndex;
        this.splits = splits;
        this.predictions = predictions;
    }

    /**
     * Adds another function on the same attribute to this one, refining the segments of this
     * function where necessary so that the sum is again piecewise constant.
     *
     * <p>The last split of {@code func} is not merged; the code relies on it being covered by
     * this function's own final split (both are positive infinity when the functions come
     * from this class's factory methods).
     *
     * @param func the function to add
     * @return this function
     * @throws IllegalArgumentException if the two functions are defined on different attributes,
     *         or if {@code func} has no splits
     */
    public Function1D add(Function1D func) {
        if (this.attIndex != func.attIndex) {
            throw new IllegalArgumentException("Cannot add functions on different terms");
        }
        if (func.splits.length == 0) {
            // Previously surfaced as an unexplained NegativeArraySizeException.
            throw new IllegalArgumentException("Cannot add a function without splits");
        }

        // Find which of the operand's boundaries (all but the last) this function lacks.
        int candidateCount = func.splits.length - 1;
        boolean[] isNew = new boolean[candidateCount];
        int newCount = 0;
        for (int i = 0; i < candidateCount; i++) {
            isNew[i] = Arrays.binarySearch(this.splits, func.splits[i]) < 0;
            if (isNew[i]) {
                newCount++;
            }
        }

        if (newCount == 0) {
            // Same segmentation is fine enough: add the operand's value on each segment.
            for (int i = 0; i < this.splits.length; i++) {
                this.predictions[i] += func.evaluate(this.splits[i]);
            }
            return this;
        }

        // Build the union of both boundary sets.
        double[] mergedSplits = new double[this.splits.length + newCount];
        System.arraycopy(this.splits, 0, mergedSplits, 0, this.splits.length);
        int next = this.splits.length;
        for (int i = 0; i < candidateCount; i++) {
            if (isNew[i]) {
                mergedSplits[next++] = func.splits[i];
            }
        }
        Arrays.sort(mergedSplits);

        // Each segment's value is sampled at its (inclusive) upper boundary. Both functions
        // must be evaluated before the fields are replaced below.
        double[] mergedPredictions = new double[mergedSplits.length];
        for (int i = 0; i < mergedSplits.length; i++) {
            mergedPredictions[i] = evaluate(mergedSplits[i]) + func.evaluate(mergedSplits[i]);
        }
        this.splits = mergedSplits;
        this.predictions = mergedPredictions;
        return this;
    }

    /**
     * Returns the index of the attribute this function is defined on.
     *
     * @return the attribute index
     */
    public int getAttributeIndex() {
        return this.attIndex;
    }

    /**
     * Sets the index of the attribute this function is defined on.
     *
     * @param attIndex the new attribute index
     */
    public void setAttributeIndex(int attIndex) {
        this.attIndex = attIndex;
    }

    /**
     * Returns the internal splits array (not a copy).
     *
     * @return the segment boundaries
     */
    public double[] getSplits() {
        return this.splits;
    }

    /**
     * Replaces the splits array. The array is stored as-is, not copied.
     *
     * @param splits the new segment boundaries
     */
    public void setSplits(double[] splits) {
        this.splits = splits;
    }

    /**
     * Returns the internal predictions array (not a copy).
     *
     * @return the per-segment values
     */
    public double[] getPredictions() {
        return this.predictions;
    }

    /**
     * Replaces the predictions array. The array is stored as-is, not copied.
     *
     * @param predictions the new per-segment values
     */
    public void setPredictions(double[] predictions) {
        this.predictions = predictions;
    }

    /**
     * Reads this function from five consecutive lines: an attribute-index line of the form
     * {@code "<label>: <int>"}, a line that is skipped, the splits array, another skipped line,
     * and the predictions array. This mirrors what {@link #write(PrintWriter)} emits after its
     * first {@code [Predictor: ...]} line, which is not consumed here (presumably the caller
     * reads it to pick the class).
     *
     * @param reader the source to read from
     * @throws Exception if reading or parsing fails
     */
    @Override // mltk.core.Writable
    public void read(BufferedReader reader) throws Exception {
        this.attIndex = Integer.parseInt(reader.readLine().split(": ")[1]);
        reader.readLine(); // "Splits: <n>" header; the count is not needed
        this.splits = ArrayUtils.parseDoubleArray(reader.readLine());
        reader.readLine(); // "Predictions: <n>" header; the count is not needed
        this.predictions = ArrayUtils.parseDoubleArray(reader.readLine());
    }

    /**
     * Writes this function as six lines: a {@code [Predictor: <class name>]} header, the
     * attribute index, the number of splits, the splits, the number of predictions, and the
     * predictions.
     *
     * @param out the destination to write to
     * @throws Exception declared by the interface; not thrown directly by this implementation
     */
    @Override // mltk.core.Writable
    public void write(PrintWriter out) throws Exception {
        // %n keeps the header's line terminator consistent with the println calls below.
        out.printf("[Predictor: %s]%n", getClass().getCanonicalName());
        out.println("AttIndex: " + this.attIndex);
        out.println("Splits: " + this.splits.length);
        out.println(Arrays.toString(this.splits));
        out.println("Predictions: " + this.predictions.length);
        out.println(Arrays.toString(this.predictions));
    }

    /**
     * Returns the segment index for the value of this function's attribute in the instance.
     *
     * @param instance the instance to look up
     * @return the segment index, as defined by {@link #getSegmentIndex(double)}
     */
    public int getSegmentIndex(Instance instance) {
        return getSegmentIndex(instance.getValue(this.attIndex));
    }

    /**
     * Returns the index of the segment containing {@code x}: the index of the first split that
     * is greater than or equal to {@code x}.
     *
     * <p>If {@code x} is greater than every split the result equals {@code splits.length},
     * which is not a valid index into the predictions. With a final split of positive infinity
     * this can only happen for {@code NaN}, which {@code Arrays.binarySearch} orders above all
     * other values.
     *
     * @param x the input value
     * @return the segment index, in the range {@code [0, splits.length]}
     */
    public int getSegmentIndex(double x) {
        int pos = Arrays.binarySearch(this.splits, x);
        // A negative result encodes the insertion point as (-(insertion point) - 1).
        return pos >= 0 ? pos : -pos - 1;
    }

    /**
     * Evaluates this function on the instance's value for this function's attribute.
     *
     * @param instance the instance to evaluate
     * @return the prediction of the segment containing the attribute value
     */
    @Override // mltk.predictor.Regressor
    public double regress(Instance instance) {
        return this.predictions[getSegmentIndex(instance)];
    }

    /**
     * Evaluates this function at {@code x}.
     *
     * @param x the input value
     * @return the prediction of the segment containing {@code x}
     * @throws ArrayIndexOutOfBoundsException if {@code x} lies beyond the last split
     *         (see {@link #getSegmentIndex(double)})
     */
    @Override // mltk.predictor.function.UnivariateFunction
    public double evaluate(double x) {
        return this.predictions[getSegmentIndex(x)];
    }

    /**
     * Returns a deep copy of this function: the splits and predictions arrays are duplicated,
     * so changes to the copy do not affect this instance.
     *
     * @return a new, independent {@code Function1D}
     */
    @Override // mltk.core.Copyable
    /* renamed from: copy */
    public Predictor copy2() {
        double[] splitsCopy = Arrays.copyOf(this.splits, this.splits.length);
        double[] predictionsCopy = Arrays.copyOf(this.predictions, this.predictions.length);
        return new Function1D(this.attIndex, splitsCopy, predictionsCopy);
    }
}
