package de.ismll.core.regression;

import de.ismll.core.Instance;
import de.ismll.core.Instances;
import de.ismll.core.regression.neuralnet.ActivationFunction;
import de.ismll.core.regression.neuralnet.ActivationIdentity;
import de.ismll.core.regression.neuralnet.FactorizationLayer;
import de.ismll.core.regression.neuralnet.FullyConnectedLayer;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/* loaded from: input_file:de/ismll/core/regression/FactorizedMultilayerPerceptron.class */
/**
 * Feed-forward regression network whose first layer is a {@link FactorizationLayer} and whose
 * remaining layers are {@link FullyConnectedLayer}s. The last layer has a single output and an
 * {@link ActivationIdentity} activation, so {@link #predict(Instance)} returns one real number.
 *
 * <p>Training is per-instance backpropagation ({@link #train(Instance)}). The class can also keep
 * an inverse of an outer-product (sum of {@code g g^T}) Hessian approximation. It is built by
 * rank-one Sherman–Morrison updates starting from {@code I / alpha}, and
 * {@link #predictWithUncertainty(Instance)} uses it to return a variance alongside the prediction.
 *
 * <p>Instances are not safe for concurrent use. Forward and backward passes store intermediate
 * outputs and deltas inside the layers, and later steps read them back.
 */
public class FactorizedMultilayerPerceptron implements Regression {
    private double initialLearnRate;
    private int numHiddenLayers;
    private int numNeurons;
    private int numTotalLayers;
    private int numLatentFeatures;
    /** Inverse of the outer-product Hessian approximation; {@code null} until built. */
    private RealMatrix inverseHessian;
    /** Length of the flat parameter vector; computed once at construction. */
    private int numberOfParameters;
    private FactorizationLayer inputLayer;
    private FullyConnectedLayer[] layerArray;
    /** The inverse Hessian is initialised to {@code I / alpha}. */
    private double alpha = 0.01d;
    /** Contributes {@code 1 / beta} to the variance returned by predictWithUncertainty. */
    private double beta = 1.0d;

    /**
     * Creates the network, passing {@code j} through to every layer constructor.
     *
     * @param learnRate initial learning rate; also passed to every layer
     * @param d2 passed through to every layer (meaning defined by the layer classes)
     * @param numLatentFeatures number of latent factors of the factorization input layer
     * @param i2 passed through to every layer (meaning defined by the layer classes)
     * @param numHiddenLayers number of hidden fully connected layers; must be non-negative
     * @param numNeurons neurons per hidden layer; the first fully connected layer takes this many inputs
     * @param activationFunction activation of the input and hidden layers
     * @param j passed through to every layer — presumably a random seed; TODO confirm
     * @throws IllegalArgumentException if {@code numHiddenLayers} is negative
     */
    public FactorizedMultilayerPerceptron(double learnRate, double d2, int numLatentFeatures, int i2,
            int numHiddenLayers, int numNeurons, ActivationFunction activationFunction, long j) {
        initHyperparameters(learnRate, numLatentFeatures, numHiddenLayers, numNeurons);
        int totalHiddenNeurons = numHiddenLayers * numNeurons;
        this.inputLayer = new FactorizationLayer(learnRate, d2, numLatentFeatures, i2, i2, numNeurons,
                totalHiddenNeurons, activationFunction, j);
        int last = this.layerArray.length - 1;
        for (int l = 0; l < last; l++) {
            this.layerArray[l] = new FullyConnectedLayer(learnRate, d2, i2, numNeurons, numNeurons,
                    totalHiddenNeurons, activationFunction, j);
        }
        // Output layer: a single unit with identity activation (plain regression output).
        this.layerArray[last] = new FullyConnectedLayer(learnRate, d2, i2, numNeurons, 1,
                totalHiddenNeurons, new ActivationIdentity(), j);
        this.numberOfParameters = countParameters();
    }

    /**
     * Creates the network using the layer constructors that take no {@code long} argument.
     * The parameters are as in
     * {@link #FactorizedMultilayerPerceptron(double, double, int, int, int, int, ActivationFunction, long)}.
     *
     * @throws IllegalArgumentException if {@code numHiddenLayers} is negative
     */
    public FactorizedMultilayerPerceptron(double learnRate, double d2, int numLatentFeatures, int i2,
            int numHiddenLayers, int numNeurons, ActivationFunction activationFunction) {
        initHyperparameters(learnRate, numLatentFeatures, numHiddenLayers, numNeurons);
        int totalHiddenNeurons = numHiddenLayers * numNeurons;
        this.inputLayer = new FactorizationLayer(learnRate, d2, numLatentFeatures, i2, i2, numNeurons,
                totalHiddenNeurons, activationFunction);
        int last = this.layerArray.length - 1;
        for (int l = 0; l < last; l++) {
            this.layerArray[l] = new FullyConnectedLayer(learnRate, d2, i2, numNeurons, numNeurons,
                    totalHiddenNeurons, activationFunction);
        }
        // Output layer: a single unit with identity activation (plain regression output).
        this.layerArray[last] = new FullyConnectedLayer(learnRate, d2, i2, numNeurons, 1,
                totalHiddenNeurons, new ActivationIdentity());
        this.numberOfParameters = countParameters();
    }

    /** Validates and stores the shared hyperparameters and allocates the (empty) layer array. */
    private void initHyperparameters(double learnRate, int numLatentFeatures, int numHiddenLayers,
            int numNeurons) {
        if (numHiddenLayers < 0) {
            throw new IllegalArgumentException("numHiddenLayers must be >= 0, got " + numHiddenLayers);
        }
        // Assign fields directly: calling overridable setters from a constructor is unsafe.
        this.initialLearnRate = learnRate;
        this.numHiddenLayers = numHiddenLayers;
        this.numNeurons = numNeurons;
        this.numTotalLayers = numHiddenLayers + 1;
        this.numLatentFeatures = numLatentFeatures;
        this.layerArray = new FullyConnectedLayer[this.numTotalLayers];
    }

    /** Counts all weights, biases and latent factors in the flat parameter layout. */
    private int countParameters() {
        int in = this.inputLayer.getInputSize();
        int out = this.inputLayer.getOutputSize();
        int count = (in + 1) * out + in * out * this.numLatentFeatures;
        for (FullyConnectedLayer layer : this.layerArray) {
            count += (layer.getInputSize() + 1) * layer.getOutputSize();
        }
        return count;
    }

    /** Trains on every instance in order, one backpropagation step per instance. */
    @Override // de.ismll.core.regression.Regression
    public void train(Instances instances) {
        for (int i = 0; i < instances.numInstances(); i++) {
            train(instances.instance(i));
        }
    }

    /**
     * Performs one backpropagation step on a single instance and then updates every layer's weights.
     *
     * @param instance training example; its values feed the input layer and its target drives the loss
     */
    public void train(Instance instance) {
        double prediction = predict(instance);
        FullyConnectedLayer outputLayer = this.layerArray[this.layerArray.length - 1];
        double[] outputDelta = outputLayer.backward(new double[] {prediction}, new double[] {instance.target()});
        propagateToInput(outputDelta);

        this.inputLayer.updateWeights(instance.getValues());
        for (int l = 0; l < this.layerArray.length; l++) {
            this.layerArray[l].updateWeights(layerInput(l));
        }
    }

    /** Runs a forward pass and returns the single output of the final layer. */
    @Override // de.ismll.core.regression.Regression
    public double predict(Instance instance) {
        double[] activation = this.inputLayer.forward(instance);
        for (FullyConnectedLayer layer : this.layerArray) {
            activation = layer.forward(activation);
        }
        return activation[0];
    }

    /** Predicts every instance, preserving order. */
    @Override // de.ismll.core.regression.Regression
    public double[] predict(Instances instances) {
        double[] predictions = new double[instances.numInstances()];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(instances.instance(i));
        }
        return predictions;
    }

    /**
     * Returns {@code {prediction, variance}}, where
     * {@code variance = 1/beta + g^T H^{-1} g}. Here {@code g} is {@link #getGradient(Instance)} and
     * {@code H^{-1}} is the stored inverse Hessian.
     *
     * @throws IllegalStateException if the inverse Hessian has not been built yet
     */
    public double[] predictWithUncertainty(Instance instance) {
        RealMatrix inverse = requireInverseHessian();
        double prediction = predict(instance);
        RealVector gradient = getGradient(instance);
        double variance = (1.0d / this.beta) + gradient.dotProduct(inverse.operate(gradient));
        return new double[] {prediction, variance};
    }

    public double getInitialLearnRate() {
        return this.initialLearnRate;
    }

    public void setInitialLearnRate(double d) {
        this.initialLearnRate = d;
    }

    public int getNumHiddenLayers() {
        return this.numHiddenLayers;
    }

    public void setNumHiddenLayers(int i) {
        this.numHiddenLayers = i;
    }

    public int getNumNeurons() {
        return this.numNeurons;
    }

    public void setNumNeurons(int i) {
        this.numNeurons = i;
    }

    public int getNumTotalLayers() {
        return this.numTotalLayers;
    }

    public void setNumTotalLayers(int i) {
        this.numTotalLayers = i;
    }

    public FactorizationLayer getInputLayer() {
        return this.inputLayer;
    }

    /** Replaces the input layer; does not recompute {@link #getNumberOfParameters()}. */
    public void setInputLayer(FactorizationLayer factorizationLayer) {
        this.inputLayer = factorizationLayer;
    }

    /** Returns the internal layer array (not a copy). */
    public FullyConnectedLayer[] getLayerArray() {
        return this.layerArray;
    }

    /** Replaces the layer array; does not recompute {@link #getNumberOfParameters()}. */
    public void setLayerArray(FullyConnectedLayer[] fullyConnectedLayerArr) {
        this.layerArray = fullyConnectedLayerArr;
    }

    /**
     * Flattens all parameters into one vector.
     *
     * <p>The layout is:
     * <ol>
     *   <li>input-layer {@code W} (input-major);</li>
     *   <li>input-layer bias;</li>
     *   <li>input-layer {@code V}, ordered by latent factor, then output unit, then input index;</li>
     *   <li>for each fully connected layer, its {@code W} (input-major) followed by its bias.</li>
     * </ol>
     * {@link #getGradient(Instance)} uses the same layout.
     */
    public RealVector getAllParameters() {
        double[] params = new double[this.numberOfParameters];
        int k = 0;
        int inputSize = this.inputLayer.getInputSize();
        int outputSize = this.inputLayer.getOutputSize();

        double[][] w = this.inputLayer.getW();
        for (int row = 0; row < inputSize; row++) {
            for (int col = 0; col < outputSize; col++) {
                params[k++] = w[row][col];
            }
        }
        double[] b = this.inputLayer.getB();
        for (int col = 0; col < outputSize; col++) {
            params[k++] = b[col];
        }
        double[][][] v = this.inputLayer.getV();
        for (int factor = 0; factor < this.numLatentFeatures; factor++) {
            for (int col = 0; col < outputSize; col++) {
                for (int row = 0; row < inputSize; row++) {
                    params[k++] = v[row][col][factor];
                }
            }
        }

        for (FullyConnectedLayer layer : this.layerArray) {
            double[][] lw = layer.getW();
            double[] lb = layer.getB();
            for (int row = 0; row < layer.getInputSize(); row++) {
                for (int col = 0; col < layer.getOutputSize(); col++) {
                    params[k++] = lw[row][col];
                }
            }
            for (int col = 0; col < layer.getOutputSize(); col++) {
                params[k++] = lb[col];
            }
        }
        return MatrixUtils.createRealVector(params);
    }

    /**
     * Computes the gradient of the network output with respect to every parameter.
     *
     * <p>The output layer's error signal comes from {@code backwardLossLess}, not from the training
     * loss. The result uses the same layout as {@link #getAllParameters()}. As a side effect, this
     * runs a forward and a backward pass, overwriting the layers' stored outputs and deltas.
     */
    public RealVector getGradient(Instance instance) {
        double prediction = predict(instance);
        double[] values = instance.getValues();
        FullyConnectedLayer outputLayer = this.layerArray[this.layerArray.length - 1];
        propagateToInput(outputLayer.backwardLossLess(new double[] {prediction}));

        double[] gradient = new double[this.numberOfParameters];
        int k = 0;
        int inputSize = this.inputLayer.getInputSize();
        int outputSize = this.inputLayer.getOutputSize();
        double[] inputDelta = this.inputLayer.getDelta();

        // Input-layer linear weights and biases.
        for (int row = 0; row < inputSize; row++) {
            for (int col = 0; col < outputSize; col++) {
                gradient[k++] = inputDelta[col] * values[row];
            }
        }
        for (int col = 0; col < outputSize; col++) {
            gradient[k++] = inputDelta[col];
        }

        // Latent factors: d/dV[j][f][q] = x_j * sum - V[j][f][q] * x_j^2, with sum from preComputeSum.
        double[][][] v = this.inputLayer.getV();
        for (int factor = 0; factor < this.numLatentFeatures; factor++) {
            for (int col = 0; col < outputSize; col++) {
                double sum = this.inputLayer.preComputeSum(values, factor, col);
                for (int row = 0; row < inputSize; row++) {
                    gradient[k++] = inputDelta[col]
                            * ((values[row] * sum) - ((v[row][col][factor] * values[row]) * values[row]));
                }
            }
        }

        // Fully connected layers: weights then biases, each fed by the previous layer's output.
        for (int l = 0; l < this.layerArray.length; l++) {
            FullyConnectedLayer layer = this.layerArray[l];
            double[] delta = layer.getDelta();
            double[] layerIn = layerInput(l);
            for (int row = 0; row < layer.getInputSize(); row++) {
                for (int col = 0; col < layer.getOutputSize(); col++) {
                    gradient[k++] = delta[col] * layerIn[row];
                }
            }
            for (int col = 0; col < layer.getOutputSize(); col++) {
                gradient[k++] = delta[col];
            }
        }
        return MatrixUtils.createRealVector(gradient);
    }

    /**
     * Propagates an error signal from the output layer down through the hidden layers into the input
     * layer. This works for any number of hidden layers, including zero.
     */
    private void propagateToInput(double[] outputDelta) {
        double[] delta = outputDelta;
        for (int l = this.layerArray.length - 2; l >= 0; l--) {
            delta = this.layerArray[l].backward(delta, this.layerArray[l + 1].getW());
        }
        this.inputLayer.backward(delta, this.layerArray[0].getW());
    }

    /** Returns the activations feeding fully connected layer {@code index} from the last forward pass. */
    private double[] layerInput(int index) {
        return index == 0 ? this.inputLayer.getOutput() : this.layerArray[index - 1].getOutput();
    }

    /** Returns the sum of gradient outer products {@code sum_i g_i g_i^T} over all instances. */
    public RealMatrix getOPHessian(Instances instances) {
        RealMatrix sum = MatrixUtils.createRealMatrix(this.numberOfParameters, this.numberOfParameters);
        for (int i = 0; i < instances.numInstances(); i++) {
            // RealMatrix.add returns a new matrix; keep the result to accumulate.
            sum = sum.add(getOPHessian(instances.instance(i)));
        }
        return sum;
    }

    /** Returns the gradient outer product {@code g g^T} for a single instance. */
    public RealMatrix getOPHessian(Instance instance) {
        RealVector gradient = getGradient(instance);
        return gradient.outerProduct(gradient);
    }

    /**
     * Builds and stores the inverse of {@code alpha I + sum_i g_i g_i^T}.
     *
     * <p>The build starts from {@code I / alpha} and applies one Sherman–Morrison rank-one update per
     * instance. The outer products are not scaled by {@code beta}; that is equivalent here because
     * this class fixes {@code beta} at 1.0.
     */
    public void getOPHessianInverse(Instances instances) {
        // scalarMultiply returns a new matrix; it does not modify the receiver.
        RealMatrix inverse = MatrixUtils.createRealIdentityMatrix(this.numberOfParameters)
                .scalarMultiply(1.0d / this.alpha);
        for (int i = 0; i < instances.numInstances(); i++) {
            inverse = shermanMorrisonUpdate(inverse, getGradient(instances.instance(i)));
        }
        this.inverseHessian = inverse;
    }

    /**
     * Folds one more instance's gradient outer product into the stored inverse Hessian.
     *
     * @throws IllegalStateException if the inverse Hessian has not been built yet
     */
    public void updateOPHessianInverse(Instance instance) {
        RealMatrix inverse = requireInverseHessian();
        this.inverseHessian = shermanMorrisonUpdate(inverse, getGradient(instance));
    }

    /**
     * Given {@code A^{-1}}, returns {@code (A + g g^T)^{-1}}, computed as
     * {@code A^{-1} - (A^{-1} g)(g^T A^{-1}) / (1 + g^T A^{-1} g)}.
     * Only matrix-vector products are needed, so the cost is O(p^2) rather than O(p^3).
     */
    private static RealMatrix shermanMorrisonUpdate(RealMatrix inverse, RealVector gradient) {
        RealVector left = inverse.operate(gradient);       // A^{-1} g
        RealVector right = inverse.preMultiply(gradient);  // g^T A^{-1}
        double denominator = 1.0d + gradient.dotProduct(left);
        return inverse.subtract(left.outerProduct(right).scalarMultiply(1.0d / denominator));
    }

    /** Returns the stored inverse Hessian, failing with a clear message if it was never built. */
    private RealMatrix requireInverseHessian() {
        if (this.inverseHessian == null) {
            throw new IllegalStateException(
                    "inverse Hessian not initialised; call getOPHessianInverse(Instances) first");
        }
        return this.inverseHessian;
    }

    /** Returns the length of the flat parameter vector as computed at construction. */
    public int getNumberOfParameters() {
        return this.numberOfParameters;
    }
}
