package de.ismll.core.regression.neuralnet;

import de.ismll.core.Instance;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:de/ismll/core/regression/neuralnet/FactorizationLayer.class */
/**
 * A neural-network layer whose pre-activation combines a linear term with factorization-machine
 * style pairwise interactions between inputs.
 *
 * <p>For output unit {@code o} the pre-activation computed by {@link #computeOutput} is
 * {@code b[o] + sum_i w[i][o] * x[i]
 *   + 0.5 * sum_f ((sum_k v[k][o][f] * x[k])^2 - sum_k (v[k][o][f] * x[k])^2)},
 * where the linear sum runs over all {@code inputSize} inputs and the pairwise sums run over the
 * instance's keys. The result is passed through the configured {@link ActivationFunction}.
 *
 * <p>The layer keeps mutable per-sample buffers ({@code output}, {@code delta}) and performs no
 * synchronization, so a single instance should not be shared between concurrent trainers.
 */
public class FactorizationLayer {

    /** Optimization mode: plain stochastic gradient descent. */
    public static final int MODE_SGD = 0;
    /** Optimization mode: SGD with momentum; this is the mode selected by the constructors. */
    public static final int MODE_MOM = 1;
    /** Optimization mode: AdaGrad (updates {@code w} and {@code b} only, see the method note). */
    public static final int MODE_ADAGRAD = 2;
    /** Initialization mode: every parameter drawn from a Gaussian scaled by {@code initStdev}. */
    public static final int MODE_GAUSSIAN = 0;
    /** Initialization mode: scaled uniform weights; this is the mode selected by the constructors. */
    public static final int MODE_WIDROW = 1;
    /** Initialization mode constant that {@link #initializeWeights(int)} does not implement (no-op). */
    public static final int MODE_FANIN = 2;

    /** Initial value of every AdaGrad accumulator, so the first step never divides by zero. */
    private static final double ADA_GRAD_EPSILON = 0.001d;

    /** Bias per output unit: {@code b[output]}. */
    private final double[] b;
    /** Linear weights: {@code w[input][output]}. */
    private final double[][] w;
    /** Latent factors: {@code v[input][output][factor]}. */
    private double[][][] v;
    private int numLatentFeatures;
    private double learnRate;
    private double momentum;
    /** Last momentum step for each latent factor; same shape as {@code v}. */
    private final double[][][] velocityLFHistory;
    /** Last momentum step for each linear weight; same shape as {@code w}. */
    private final double[][] velocityHistory;
    private double initStdev = 1.0E-4d;
    private ActivationFunction activationFunction;
    private final Random rand;
    /** Activated outputs from the most recent {@link #forward(Instance)} call. */
    private final double[] output;
    /** Error terms from the most recent {@code backward} call. */
    private final double[] delta;
    private int inputSize;
    private int outputSize;
    private final int numAttributesOfData;
    private int numTotalHiddenNeurons;
    /** Running AdaGrad sums of squared gradients for {@code w}, seeded with epsilon. */
    private final double[][] adaGradDenominator;
    /** Running AdaGrad sums of squared gradients for {@code b}, seeded with epsilon. */
    private final double[] adaGradDenominatorBias;
    private int optimizationMode;
    private int initializationMode;

    /**
     * Creates a layer whose weights are initialized from an unseeded random generator.
     *
     * @param learnRate step size used by every update rule
     * @param momentum coefficient applied to the previous step in {@link #MODE_MOM}
     * @param numLatentFeatures number of latent factors per (input, output) pair
     * @param numAttributesOfData used only as the root exponent in {@link #MODE_WIDROW} scaling
     * @param inputSize number of inputs to the layer
     * @param outputSize number of output units
     * @param numTotalHiddenNeurons used only as the base in {@link #MODE_WIDROW} scaling
     * @param activationFunction non-linearity applied to each unit's pre-activation
     */
    public FactorizationLayer(double learnRate, double momentum, int numLatentFeatures,
            int numAttributesOfData, int inputSize, int outputSize, int numTotalHiddenNeurons,
            ActivationFunction activationFunction) {
        this(learnRate, momentum, numLatentFeatures, numAttributesOfData, inputSize, outputSize,
                numTotalHiddenNeurons, activationFunction, new Random());
    }

    /**
     * Creates a layer whose weights are initialized from a generator seeded with {@code seed}, so
     * construction is reproducible.
     *
     * @param learnRate step size used by every update rule
     * @param momentum coefficient applied to the previous step in {@link #MODE_MOM}
     * @param numLatentFeatures number of latent factors per (input, output) pair
     * @param numAttributesOfData used only as the root exponent in {@link #MODE_WIDROW} scaling
     * @param inputSize number of inputs to the layer
     * @param outputSize number of output units
     * @param numTotalHiddenNeurons used only as the base in {@link #MODE_WIDROW} scaling
     * @param activationFunction non-linearity applied to each unit's pre-activation
     * @param seed seed for the internal {@link Random}
     */
    public FactorizationLayer(double learnRate, double momentum, int numLatentFeatures,
            int numAttributesOfData, int inputSize, int outputSize, int numTotalHiddenNeurons,
            ActivationFunction activationFunction, long seed) {
        this(learnRate, momentum, numLatentFeatures, numAttributesOfData, inputSize, outputSize,
                numTotalHiddenNeurons, activationFunction, new Random(seed));
    }

    /** Shared construction logic; the public constructors differ only in how {@code rand} is built. */
    private FactorizationLayer(double learnRate, double momentum, int numLatentFeatures,
            int numAttributesOfData, int inputSize, int outputSize, int numTotalHiddenNeurons,
            ActivationFunction activationFunction, Random rand) {
        this.rand = rand;
        this.numAttributesOfData = numAttributesOfData;
        this.optimizationMode = MODE_MOM;
        this.initializationMode = MODE_WIDROW;
        this.numTotalHiddenNeurons = numTotalHiddenNeurons;
        this.activationFunction = activationFunction;
        this.numLatentFeatures = numLatentFeatures;
        this.learnRate = learnRate;
        this.momentum = momentum;
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        this.velocityHistory = new double[inputSize][outputSize];
        this.velocityLFHistory = new double[inputSize][outputSize][numLatentFeatures];
        this.w = new double[inputSize][outputSize];
        this.v = new double[inputSize][outputSize][numLatentFeatures];
        this.b = new double[outputSize];
        this.adaGradDenominator = new double[inputSize][outputSize];
        this.adaGradDenominatorBias = new double[outputSize];
        this.output = new double[outputSize];
        this.delta = new double[outputSize];
        for (double[] row : adaGradDenominator) {
            Arrays.fill(row, ADA_GRAD_EPSILON);
        }
        Arrays.fill(adaGradDenominatorBias, ADA_GRAD_EPSILON);
        initializeWeights(initializationMode);
    }

    /** Sets the step size used by all update rules. */
    public void setLearnRate(double learnRate) {
        this.learnRate = learnRate;
    }

    /**
     * (Re)initializes the parameters according to {@code mode}.
     *
     * <p>{@link #MODE_GAUSSIAN} draws {@code w}, then {@code v}, then {@code b} from a Gaussian
     * scaled by {@code initStdev}. {@link #MODE_WIDROW} draws {@code w} and {@code b} uniformly in
     * [-1, 1) and rescales them, then draws {@code v} from the Gaussian. Any other mode, including
     * {@link #MODE_FANIN}, leaves the parameters untouched.
     *
     * @param mode one of the {@code MODE_*} initialization constants
     */
    public void initializeWeights(int mode) {
        switch (mode) {
            case MODE_GAUSSIAN:
                initializeGaussian();
                return;
            case MODE_WIDROW:
                initializeWidrow();
                return;
            default:
                return;
        }
    }

    /** Gaussian initialization; draw order (w, v, b) is kept so seeded runs are reproducible. */
    private void initializeGaussian() {
        for (double[] row : w) {
            for (int o = 0; o < row.length; o++) {
                row[o] = rand.nextGaussian() * initStdev;
            }
        }
        initializeLatentFactorsGaussian();
        for (int o = 0; o < b.length; o++) {
            b[o] = rand.nextGaussian() * initStdev;
        }
    }

    /**
     * Uniform initialization of {@code w} and {@code b}, rescaled so their joint Euclidean norm is
     * {@code 0.7 * numTotalHiddenNeurons^(1 / numAttributesOfData)}. The norm is taken over the
     * whole layer, not per unit. Latent factors are then drawn from the Gaussian.
     */
    private void initializeWidrow() {
        for (double[] row : w) {
            for (int o = 0; o < row.length; o++) {
                row[o] = uniformSymmetric();
            }
        }
        for (int o = 0; o < b.length; o++) {
            b[o] = uniformSymmetric();
        }
        // NOTE(review): numAttributesOfData == 0 makes the exponent infinite; confirm callers never pass 0.
        double beta = 0.7d * Math.pow(numTotalHiddenNeurons, 1.0d / numAttributesOfData);
        double squaredNorm = 0.0d;
        for (double[] row : w) {
            for (double weight : row) {
                squaredNorm += weight * weight;
            }
        }
        for (double bias : b) {
            squaredNorm += bias * bias;
        }
        double norm = Math.sqrt(squaredNorm);
        for (double[] row : w) {
            for (int o = 0; o < row.length; o++) {
                row[o] = (beta * row[o]) / norm;
            }
        }
        for (int o = 0; o < b.length; o++) {
            b[o] = (beta * b[o]) / norm;
        }
        initializeLatentFactorsGaussian();
    }

    /** Returns a uniform draw in [-1, 1). */
    private double uniformSymmetric() {
        return (rand.nextDouble() - 0.5d) * 2.0d;
    }

    /** Fills {@code v} in [input][output][factor] order with Gaussian draws scaled by initStdev. */
    private void initializeLatentFactorsGaussian() {
        final int inputs = getInputSize();
        final int outputs = getOutputSize();
        for (int i = 0; i < inputs; i++) {
            for (int o = 0; o < outputs; o++) {
                for (int f = 0; f < numLatentFeatures; f++) {
                    v[i][o][f] = rand.nextGaussian() * initStdev;
                }
            }
        }
    }

    /**
     * Back-propagates the next layer's error into unit {@code unit} of this layer.
     *
     * <p>The derivative is evaluated at this layer's stored (activated) output.
     */
    private double computeDelta(double[] nextDelta, double[][] nextWeights, int unit) {
        double weighted = 0.0d;
        for (int k = 0; k < nextDelta.length; k++) {
            weighted += nextWeights[unit][k] * nextDelta[k];
        }
        return weighted * activationFunction.computeDerivative(output[unit]);
    }

    /**
     * Computes this (hidden) layer's error terms from the following layer.
     *
     * @param nextDelta error terms of the following layer
     * @param nextWeights weights of the following layer, indexed {@code [thisUnit][nextUnit]}
     * @return the internal delta buffer (live reference, overwritten on the next call)
     */
    public double[] backward(double[] nextDelta, double[][] nextWeights) {
        for (int o = 0; o < getOutputSize(); o++) {
            delta[o] = computeDelta(nextDelta, nextWeights, o);
        }
        return delta;
    }

    /**
     * Computes this (output) layer's error terms as
     * {@code (outputValues[o] - targetValues[o]) * f'(outputValues[o])}.
     *
     * @param outputValues values the error and derivative are evaluated at (presumably this
     *     layer's output; confirm at call sites)
     * @param targetValues values subtracted from {@code outputValues}
     * @return the internal delta buffer (live reference, overwritten on the next call)
     */
    public double[] backward(double[] outputValues, double[] targetValues) {
        for (int o = 0; o < getOutputSize(); o++) {
            delta[o] = (outputValues[o] - targetValues[o])
                    * activationFunction.computeDerivative(outputValues[o]);
        }
        return delta;
    }

    /**
     * Applies one parameter update using the current deltas and the configured optimization mode.
     *
     * @param input the dense input vector the deltas were computed for
     * @throws IllegalStateException if the optimization mode is not one of the {@code MODE_*} values
     */
    public void updateWeights(double[] input) {
        switch (optimizationMode) {
            case MODE_SGD:
                updateWeightsSGD(input);
                return;
            case MODE_MOM:
                updateWeightsMomentum(input);
                return;
            case MODE_ADAGRAD:
                updateWeightsAdagrad(input);
                return;
            default:
                // Previously this called System.exit(1); a library must not terminate the JVM.
                throw new IllegalStateException(
                        "No update mode defined for optimizationMode=" + optimizationMode);
        }
    }

    /**
     * Returns {@code sum_i v[i][outputUnit][factor] * input[i]} over every element of {@code input}.
     *
     * @param input dense input vector
     * @param factor latent factor index
     * @param outputUnit output unit index
     * @return the factor-weighted input sum
     */
    public double preComputeSum(double[] input, int factor, int outputUnit) {
        double sum = 0.0d;
        for (int i = 0; i < input.length; i++) {
            sum += v[i][outputUnit][factor] * input[i];
        }
        return sum;
    }

    /**
     * Partial derivative of the pairwise term with respect to {@code v[i][o][f]}:
     * {@code x_i * sum - v[i][o][f] * x_i^2}.
     */
    private double latentGradient(double[] input, double factorSum, int i, int o, int f) {
        return (input[i] * factorSum) - ((v[i][o][f] * input[i]) * input[i]);
    }

    /**
     * Plain SGD update of {@code w}, {@code v} and {@code b} using the stored deltas.
     *
     * @param input dense input vector the deltas were computed for
     */
    public void updateWeightsSGD(double[] input) {
        final int inputs = getInputSize();
        final int outputs = getOutputSize();
        for (int i = 0; i < inputs; i++) {
            for (int o = 0; o < outputs; o++) {
                w[i][o] -= learnRate * (delta[o] * input[i]);
            }
        }
        for (int f = 0; f < numLatentFeatures; f++) {
            for (int o = 0; o < outputs; o++) {
                // Computed once per (factor, output) from the pre-update factors.
                double factorSum = preComputeSum(input, f, o);
                for (int i = 0; i < input.length; i++) {
                    v[i][o][f] -= learnRate * (delta[o] * latentGradient(input, factorSum, i, o, f));
                }
            }
        }
        for (int o = 0; o < outputs; o++) {
            b[o] -= learnRate * delta[o];
        }
    }

    /**
     * SGD-with-momentum update of {@code w} and {@code v}. Each step is
     * {@code momentum * previousStep - learnRate * gradient}.
     *
     * <p>NOTE(review): the bias is updated with plain SGD (no velocity term); confirm this is intended.
     *
     * @param input dense input vector the deltas were computed for
     */
    public void updateWeightsMomentum(double[] input) {
        final int inputs = getInputSize();
        final int outputs = getOutputSize();
        for (int i = 0; i < inputs; i++) {
            for (int o = 0; o < outputs; o++) {
                double step = (momentum * velocityHistory[i][o]) - ((learnRate * delta[o]) * input[i]);
                w[i][o] += step;
                velocityHistory[i][o] = step;
            }
        }
        for (int f = 0; f < numLatentFeatures; f++) {
            for (int o = 0; o < outputs; o++) {
                double factorSum = preComputeSum(input, f, o);
                for (int i = 0; i < input.length; i++) {
                    double step = (momentum * velocityLFHistory[i][o][f])
                            - ((learnRate * delta[o]) * latentGradient(input, factorSum, i, o, f));
                    v[i][o][f] += step;
                    velocityLFHistory[i][o][f] = step;
                }
            }
        }
        for (int o = 0; o < outputs; o++) {
            b[o] -= learnRate * delta[o];
        }
    }

    /**
     * AdaGrad update of {@code w} and {@code b}. Each parameter's step is scaled by the inverse
     * square root of its accumulated squared gradients. The accumulator is read before the current
     * gradient is added.
     *
     * <p>NOTE(review): the latent factors {@code v} are not updated in this mode, so the pairwise
     * term keeps its initial values; confirm this is intended.
     *
     * @param input dense input vector the deltas were computed for
     */
    public void updateWeightsAdagrad(double[] input) {
        final int inputs = getInputSize();
        final int outputs = getOutputSize();
        for (int i = 0; i < inputs; i++) {
            for (int o = 0; o < outputs; o++) {
                double grad = delta[o] * input[i];
                w[i][o] -= (learnRate / Math.sqrt(adaGradDenominator[i][o])) * grad;
                adaGradDenominator[i][o] += grad * grad;
            }
        }
        for (int o = 0; o < outputs; o++) {
            b[o] -= (learnRate / Math.sqrt(adaGradDenominatorBias[o])) * delta[o];
            adaGradDenominatorBias[o] += delta[o] * delta[o];
        }
    }

    /**
     * Computes the activated output of every unit for {@code instance}.
     *
     * @return the internal output buffer (live reference, overwritten on the next call)
     */
    public double[] forward(Instance instance) {
        for (int o = 0; o < getOutputSize(); o++) {
            output[o] = computeOutput(instance, o);
        }
        return output;
    }

    /**
     * Computes the activated output of one unit: bias + linear term + half the FM pairwise term.
     *
     * <p>NOTE(review): {@code values} is indexed by position in the linear loop and by key in the
     * pairwise loop. This is only consistent if {@code getValues()} is a dense vector indexed by
     * feature id and {@code getKeys()} lists feature ids; confirm against {@code Instance}.
     */
    protected double computeOutput(Instance instance, int outputUnit) {
        double[] values = instance.getValues();
        int[] keys = instance.getKeys();
        double linear = b[outputUnit];
        for (int i = 0; i < getInputSize(); i++) {
            linear += w[i][outputUnit] * values[i];
        }
        double pairwise = 0.0d;
        for (int f = 0; f < numLatentFeatures; f++) {
            double sum = 0.0d;
            double sumOfSquares = 0.0d;
            for (int key : keys) {
                double term = v[key][outputUnit][f] * values[key];
                sum += term;
                sumOfSquares += term * term;
            }
            pairwise += (sum * sum) - sumOfSquares;
        }
        return activationFunction.computeOutput(linear + (0.5d * pairwise));
    }

    /** Returns the live linear weight matrix {@code w[input][output]} (not a copy). */
    public double[][] getW() {
        return w;
    }

    /** Returns the live bias vector (not a copy). */
    public double[] getB() {
        return b;
    }

    /** Sets the bias of output unit {@code unit}. */
    public void setB(int unit, double value) {
        b[unit] = value;
    }

    /** Sets the linear weight from input {@code input} to output unit {@code unit}. */
    public void setW(int input, int unit, double value) {
        w[input][unit] = value;
    }

    /** Returns the live output buffer filled by the last {@link #forward(Instance)} call. */
    public double[] getOutput() {
        return output;
    }

    /** Returns the momentum coefficient. */
    public double getMomentum() {
        return momentum;
    }

    /** Sets the momentum coefficient. */
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /** Returns the Gaussian scale used for initialization. */
    public double getInitStdev() {
        return initStdev;
    }

    /** Sets the Gaussian scale; takes effect on the next {@link #initializeWeights(int)} call. */
    public void setInitStdev(double initStdev) {
        this.initStdev = initStdev;
    }

    /** Returns the activation function. */
    public ActivationFunction getActivationFunction() {
        return activationFunction;
    }

    /** Replaces the activation function. */
    public void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    /** Returns the live delta buffer filled by the last {@code backward} call. */
    public double[] getDelta() {
        return delta;
    }

    /** Returns the optimization mode (one of the {@code MODE_SGD/MOM/ADAGRAD} constants). */
    public int getOptimizationMode() {
        return optimizationMode;
    }

    /** Sets the optimization mode; an unknown value fails at the next {@link #updateWeights}. */
    public void setOptimizationMode(int optimizationMode) {
        this.optimizationMode = optimizationMode;
    }

    /** Returns the stored initialization mode. */
    public int getInitializationMode() {
        return initializationMode;
    }

    /** Stores an initialization mode; does not re-initialize the parameters. */
    public void setInitializationMode(int initializationMode) {
        this.initializationMode = initializationMode;
    }

    /** Returns the hidden-neuron count used by {@link #MODE_WIDROW} scaling. */
    public int getNumTotalHiddenNeurons() {
        return numTotalHiddenNeurons;
    }

    /** Sets the hidden-neuron count used by {@link #MODE_WIDROW} scaling. */
    public void setNumTotalHiddenNeurons(int numTotalHiddenNeurons) {
        this.numTotalHiddenNeurons = numTotalHiddenNeurons;
    }

    /** Returns the number of latent factors. */
    public int getNumLatentFeatures() {
        return numLatentFeatures;
    }

    /**
     * Sets the number of latent factors.
     *
     * <p>NOTE(review): {@code v} and its velocity history are not resized, so a value larger than
     * the allocated factor dimension causes out-of-bounds access.
     */
    public void setNumLatentFeatures(int numLatentFeatures) {
        this.numLatentFeatures = numLatentFeatures;
    }

    /** Returns the live latent-factor tensor {@code v[input][output][factor]} (not a copy). */
    public double[][][] getV() {
        return v;
    }

    /** Replaces the latent-factor tensor; its dimensions are not validated. */
    public void setV(double[][][] v) {
        this.v = v;
    }

    /** Returns the number of inputs. */
    public int getInputSize() {
        return inputSize;
    }

    /** Sets the number of inputs; does not reallocate any arrays. */
    public void setInputSize(int inputSize) {
        this.inputSize = inputSize;
    }

    /** Returns the number of output units. */
    public int getOutputSize() {
        return outputSize;
    }

    /** Sets the number of output units; does not reallocate any arrays. */
    public void setOutputSize(int outputSize) {
        this.outputSize = outputSize;
    }
}
