package de.ismll.core.regression;

import de.ismll.core.Instance;
import de.ismll.core.Instances;
import de.ismll.core.regression.neuralnet.ActivationFunction;
import de.ismll.core.regression.neuralnet.ActivationIdentity;
import de.ismll.core.regression.neuralnet.FullyConnectedLayer;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.jblas.DoubleMatrix;

/* loaded from: input_file:de/ismll/core/regression/MultiLayerPerceptron.class */
/**
 * Fully connected feed-forward regression network built from {@link FullyConnectedLayer}s and
 * trained online, one instance at a time.
 *
 * <p>The network has {@code numHiddenLayers + 2} layers:
 * <ul>
 *   <li>an input layer mapping the instance features to {@code numNeurons} units;</li>
 *   <li>{@code numHiddenLayers} layers of {@code numNeurons -> numNeurons};</li>
 *   <li>a single-output layer with an {@link ActivationIdentity} activation.</li>
 * </ul>
 *
 * <p>Besides point predictions, the class can build the sum of per-instance gradient outer
 * products (called the "OP Hessian" here). It can also maintain the inverse of
 * {@code alpha * I + sum(g g^T)} via Sherman–Morrison updates, which
 * {@link #predictWithUncertainty(Instance)} uses to attach a spread estimate to a prediction.
 */
public class MultiLayerPerceptron implements Regression {
    private double initialLearnRate;
    private int numHiddenLayers;
    private int numNeurons;
    private int numTotalLayers;
    // NOTE(review): not read or written anywhere in this class; kept as-is, TODO confirm and remove.
    private double[][] parameterArray;
    private FullyConnectedLayer[] layerArray;
    // Inverse of (alpha * I + sum of gradient outer products); null until getOPHessianInverse runs.
    private RealMatrix inverseHessian;
    // Total number of weights + biases over all layers, i.e. the length of every gradient vector.
    private int numberOfParameters;
    // NOTE(review): not read or written anywhere in this class; kept as-is, TODO confirm and remove.
    private double[] finalParameters;
    // The inverse OP Hessian is initialised to I / alpha, i.e. alpha acts as a ridge term.
    private double alpha = 0.01d;
    // NOTE(review): not read anywhere in this class; kept as-is, TODO confirm intended use.
    private double beta = 1.0d;

    /**
     * Builds the network, forwarding {@code seed} to every layer.
     *
     * @param learnRate          initial learning rate, stored and forwarded to every layer
     * @param layerParam         second hyper-parameter forwarded unchanged to every layer
     *                           (its meaning is defined by FullyConnectedLayer — TODO confirm)
     * @param numInputs          input size of the first layer (number of instance features); it is
     *                           also forwarded as the third constructor argument of every layer
     * @param numHiddenLayers    number of {@code numNeurons -> numNeurons} layers, must be {@code >= 0}
     * @param numNeurons         width of every non-output layer, must be {@code >= 1}
     * @param activationFunction activation for all layers except the output layer (identity)
     * @param seed               forwarded to every layer; presumably the RNG seed for weight
     *                           initialisation — confirm against FullyConnectedLayer
     * @throws IllegalArgumentException if {@code numHiddenLayers < 0} or {@code numNeurons < 1}
     */
    public MultiLayerPerceptron(double learnRate, double layerParam, int numInputs, int numHiddenLayers,
            int numNeurons, ActivationFunction activationFunction, long seed) {
        initTopology(learnRate, numHiddenLayers, numNeurons);
        int totalHiddenNeurons = numHiddenLayers * numNeurons;
        this.layerArray[0] = new FullyConnectedLayer(learnRate, layerParam, numInputs, numInputs,
                getNumNeurons(), totalHiddenNeurons, activationFunction, seed);
        for (int l = 1; l <= getNumHiddenLayers(); l++) {
            this.layerArray[l] = new FullyConnectedLayer(learnRate, layerParam, numInputs, getNumNeurons(),
                    getNumNeurons(), totalHiddenNeurons, activationFunction, seed);
        }
        this.layerArray[getNumHiddenLayers() + 1] = new FullyConnectedLayer(learnRate, layerParam, numInputs,
                getNumNeurons(), 1, totalHiddenNeurons, new ActivationIdentity(), seed);
        this.numberOfParameters = countParameters();
    }

    /**
     * Builds the network using the FullyConnectedLayer constructor that takes no seed argument.
     *
     * @param learnRate          initial learning rate, stored and forwarded to every layer
     * @param layerParam         second hyper-parameter forwarded unchanged to every layer
     *                           (its meaning is defined by FullyConnectedLayer — TODO confirm)
     * @param numInputs          input size of the first layer (number of instance features); it is
     *                           also forwarded as the third constructor argument of every layer
     * @param numHiddenLayers    number of {@code numNeurons -> numNeurons} layers, must be {@code >= 0}
     * @param numNeurons         width of every non-output layer, must be {@code >= 1}
     * @param activationFunction activation for all layers except the output layer (identity)
     * @throws IllegalArgumentException if {@code numHiddenLayers < 0} or {@code numNeurons < 1}
     */
    public MultiLayerPerceptron(double learnRate, double layerParam, int numInputs, int numHiddenLayers,
            int numNeurons, ActivationFunction activationFunction) {
        initTopology(learnRate, numHiddenLayers, numNeurons);
        int totalHiddenNeurons = numHiddenLayers * numNeurons;
        this.layerArray[0] = new FullyConnectedLayer(learnRate, layerParam, numInputs, numInputs,
                getNumNeurons(), totalHiddenNeurons, activationFunction);
        for (int l = 1; l <= getNumHiddenLayers(); l++) {
            this.layerArray[l] = new FullyConnectedLayer(learnRate, layerParam, numInputs, getNumNeurons(),
                    getNumNeurons(), totalHiddenNeurons, activationFunction);
        }
        this.layerArray[getNumHiddenLayers() + 1] = new FullyConnectedLayer(learnRate, layerParam, numInputs,
                getNumNeurons(), 1, totalHiddenNeurons, new ActivationIdentity());
        this.numberOfParameters = countParameters();
    }

    /** Validates and stores the topology settings and allocates the (still empty) layer array. */
    private void initTopology(double learnRate, int hiddenLayers, int neuronsPerLayer) {
        if (hiddenLayers < 0) {
            throw new IllegalArgumentException("numHiddenLayers must be >= 0, got " + hiddenLayers);
        }
        if (neuronsPerLayer < 1) {
            throw new IllegalArgumentException("numNeurons must be >= 1, got " + neuronsPerLayer);
        }
        setInitialLearnRate(learnRate);
        setNumHiddenLayers(hiddenLayers);
        setNumNeurons(neuronsPerLayer);
        setNumTotalLayers(getNumHiddenLayers() + 2);
        this.layerArray = new FullyConnectedLayer[getNumTotalLayers()];
    }

    /** Counts weights ({@code in * out}) plus biases ({@code out}) over all layers. */
    private int countParameters() {
        int count = 0;
        for (FullyConnectedLayer layer : this.layerArray) {
            count += (layer.getInputSize() + 1) * layer.getOutputSize();
        }
        return count;
    }

    /** Runs one online training step per instance, in index order. */
    @Override // de.ismll.core.regression.Regression
    public void train(Instances instances) {
        for (int i = 0; i < instances.numInstances(); i++) {
            train(instances.instance(i));
        }
    }

    /**
     * Performs one backpropagation step for {@code instance}, given the network output {@code d}.
     *
     * <p>The weight updates read each layer's {@code getOutput()}. These are presumably cached by the
     * most recent forward pass, so {@code d} should come from {@link #predict(Instance)} on the same
     * instance, as {@link #train(Instance)} does.
     *
     * @param instance training example providing the features and {@code target()}
     * @param d        network output for {@code instance}
     */
    public void train(Instance instance, double d) {
        double[] values = instance.getValues();
        int last = getNumTotalLayers() - 1;
        // Output layer: loss-based backward pass against the true target.
        double[] delta = this.layerArray[last].backward(new double[]{d}, new double[]{instance.target()});
        // Remaining layers: propagate through the next layer's weights.
        for (int l = last - 1; l >= 0; l--) {
            delta = this.layerArray[l].backward(delta, this.layerArray[l + 1].getW());
        }
        // Weights are updated only after every delta is computed, so backward saw the old weights.
        this.layerArray[0].updateWeights(values);
        for (int l = 1; l <= last; l++) {
            this.layerArray[l].updateWeights(this.layerArray[l - 1].getOutput());
        }
    }

    /** Runs a forward pass on {@code instance} and then one backpropagation step on it. */
    public void train(Instance instance) {
        train(instance, predict(instance));
    }

    /**
     * Forward pass through all layers.
     *
     * @return the single output of the last layer
     */
    @Override // de.ismll.core.regression.Regression
    public double predict(Instance instance) {
        double[] activation = this.layerArray[0].forward(instance.getValues());
        for (int l = 1; l < this.layerArray.length; l++) {
            activation = this.layerArray[l].forward(activation);
        }
        return activation[0];
    }

    /** Predicts every instance; element {@code i} of the result belongs to {@code instances.instance(i)}. */
    @Override // de.ismll.core.regression.Regression
    public double[] predict(Instances instances) {
        double[] predictions = new double[instances.numInstances()];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(instances.instance(i));
        }
        return predictions;
    }

    /**
     * Predicts {@code instance} together with a spread estimate.
     *
     * @return {@code {prediction, sqrt(g^T H g)}}, where {@code g} is {@link #getGradient(Instance)}
     *         and {@code H} is the stored inverse OP Hessian
     * @throws IllegalStateException if the inverse OP Hessian has not been computed yet
     */
    public double[] predictWithUncertainty(Instance instance) {
        RealMatrix inverse = requireInverseHessian();
        double prediction = predict(instance);
        RealVector gradient = MatrixUtils.createRealVector(getGradient(instance));
        double quadraticForm = gradient.dotProduct(inverse.operate(gradient));
        return new double[]{prediction, Math.sqrt(quadraticForm)};
    }

    public double getInitialLearnRate() {
        return this.initialLearnRate;
    }

    public void setInitialLearnRate(double d) {
        this.initialLearnRate = d;
    }

    public int getNumHiddenLayers() {
        return this.numHiddenLayers;
    }

    public void setNumHiddenLayers(int i) {
        this.numHiddenLayers = i;
    }

    public int getNumNeurons() {
        return this.numNeurons;
    }

    public void setNumNeurons(int i) {
        this.numNeurons = i;
    }

    public int getNumTotalLayers() {
        return this.numTotalLayers;
    }

    public void setNumTotalLayers(int i) {
        this.numTotalLayers = i;
    }

    /** Returns the internal layer array (not a copy). */
    public FullyConnectedLayer[] getLayerArray() {
        return this.layerArray;
    }

    /**
     * Flattens all parameters into one vector, using the same layout as {@link #getGradient(Instance)}.
     *
     * <p>For each layer in order, the layout is the weights {@code W[in][out]} in row-major order,
     * followed by the biases {@code B[out]}.
     */
    public RealVector getAllParameters() {
        double[] flat = new double[this.numberOfParameters];
        int k = 0;
        for (FullyConnectedLayer layer : this.layerArray) {
            double[][] weights = layer.getW();
            for (int row = 0; row < layer.getInputSize(); row++) {
                for (int col = 0; col < layer.getOutputSize(); col++) {
                    flat[k++] = weights[row][col];
                }
            }
            double[] biases = layer.getB();
            for (int col = 0; col < layer.getOutputSize(); col++) {
                flat[k++] = biases[col];
            }
        }
        return MatrixUtils.createRealVector(flat);
    }

    /**
     * Computes the per-parameter gradient for {@code instance}, in the layout of
     * {@link #getAllParameters()}.
     *
     * <p>The backward pass is seeded with {@code backwardLossLess(prediction)} on the output layer,
     * rather than the loss-based {@code backward} used in training. Presumably this yields
     * d(output)/d(parameter) — confirm against FullyConnectedLayer. Each weight entry is
     * {@code delta[out] * input[in]} and each bias entry is {@code delta[out]}.
     *
     * <p>Side effect: runs a forward and a backward pass, which update the layers' cached
     * outputs and deltas. Weights are not changed.
     */
    public double[] getGradient(Instance instance) {
        double[] gradient = new double[this.numberOfParameters];
        double prediction = predict(instance);
        double[] values = instance.getValues();
        int last = getNumTotalLayers() - 1;
        double[] delta = this.layerArray[last].backwardLossLess(new double[]{prediction});
        for (int l = last - 1; l >= 0; l--) {
            delta = this.layerArray[l].backward(delta, this.layerArray[l + 1].getW());
        }
        int k = 0;
        for (int l = 0; l < this.layerArray.length; l++) {
            FullyConnectedLayer layer = this.layerArray[l];
            // The first layer's input is the raw feature vector; later layers read the previous output.
            double[] input = (l == 0) ? values : this.layerArray[l - 1].getOutput();
            double[] layerDelta = layer.getDelta();
            for (int row = 0; row < layer.getInputSize(); row++) {
                for (int col = 0; col < layer.getOutputSize(); col++) {
                    gradient[k++] = layerDelta[col] * input[row];
                }
            }
            for (int col = 0; col < layer.getOutputSize(); col++) {
                gradient[k++] = layerDelta[col];
            }
        }
        return gradient;
    }

    /**
     * Sums the gradient outer products {@code g g^T} over all instances.
     *
     * @return a {@code p x p} matrix ({@code p = getNumberOfParameters()}); all zeros for no instances
     */
    public RealMatrix getOPHessian(Instances instances) {
        RealMatrix sum = MatrixUtils.createRealMatrix(this.numberOfParameters, this.numberOfParameters);
        for (int i = 0; i < instances.numInstances(); i++) {
            // RealMatrix.add returns a new matrix; the result must be kept to accumulate.
            sum = sum.add(getOPHessian(instances.instance(i)));
        }
        return sum;
    }

    /** Returns the outer product {@code g g^T} of the gradient for a single instance. */
    public RealMatrix getOPHessian(Instance instance) {
        RealVector gradient = MatrixUtils.createRealVector(getGradient(instance));
        return gradient.outerProduct(gradient);
    }

    /**
     * Computes and stores {@code (alpha * I + sum(g g^T))^-1} over {@code instances}.
     *
     * <p>The computation starts from {@code I / alpha} and applies one Sherman–Morrison rank-one
     * update per instance. The method name is kept for compatibility; it stores the result
     * instead of returning it.
     */
    public void getOPHessianInverse(Instances instances) {
        RealMatrix inverse = MatrixUtils.createRealIdentityMatrix(this.numberOfParameters)
                .scalarMultiply(1.0d / this.alpha);
        for (int i = 0; i < instances.numInstances(); i++) {
            inverse = shermanMorrisonUpdate(inverse, getGradient(instances.instance(i)));
        }
        this.inverseHessian = inverse;
    }

    /**
     * Folds the gradient outer product of one more instance into the stored inverse OP Hessian.
     *
     * @throws IllegalStateException if the inverse OP Hessian has not been computed yet
     */
    public void updateOPHessianInverse(Instance instance) {
        this.inverseHessian = shermanMorrisonUpdate(requireInverseHessian(), getGradient(instance));
    }

    /**
     * Sherman–Morrison update {@code (A + g g^T)^-1 = H - (H g)(g^T H) / (1 + g^T H g)}, with
     * {@code H = A^-1}.
     *
     * <p>{@code (H g)(g^T H)} equals {@code H g g^T H} but needs only O(p^2) work instead of O(p^3).
     */
    private static RealMatrix shermanMorrisonUpdate(RealMatrix inverse, double[] gradient) {
        RealVector g = MatrixUtils.createRealVector(gradient);
        RealVector hg = inverse.operate(g);     // H g
        RealVector gh = inverse.preMultiply(g); // g^T H
        double denominator = 1.0d + g.dotProduct(hg);
        return inverse.subtract(hg.outerProduct(gh).scalarMultiply(1.0d / denominator));
    }

    /** Returns the stored inverse OP Hessian, failing with a clear message if it is missing. */
    private RealMatrix requireInverseHessian() {
        if (this.inverseHessian == null) {
            throw new IllegalStateException(
                    "inverse OP Hessian not initialised; call getOPHessianInverse(Instances) first");
        }
        return this.inverseHessian;
    }

    /** Total number of weights and biases, i.e. the length of every gradient vector. */
    public int getNumberOfParameters() {
        return this.numberOfParameters;
    }
}
