package de.ismll.core.regression;

import de.ismll.core.Instance;
import de.ismll.core.Instances;
import java.util.Random;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/* loaded from: input_file:de/ismll/core/regression/FactorizationMachineRegression.class */
/**
 * Second-order factorization machine for regression, trained online with either plain stochastic
 * gradient descent or AdaGrad.
 *
 * <p>For a sparse instance with keys {@code j} and values {@code x_j} the model predicts
 *
 * <pre>
 *   y(x) = w0 + sum_j w_j x_j + sum_f 0.5 * ((sum_j v_jf x_j)^2 - sum_j (v_jf x_j)^2)
 * </pre>
 *
 * and clamps the result to {@code [minTarget, maxTarget]}. Each of the three terms can be switched
 * off with {@link #setUseW0}, {@link #setUseW} and {@link #setUseV}.
 *
 * <p>The no-argument constructor leaves {@link #w} and {@link #v} unset; call
 * {@link #initializeModel(int)} before predicting or training. Instances are mutable and perform no
 * synchronization.
 */
public class FactorizationMachineRegression implements Regression {

    /** {@link #setOptimizationMode(int)} value selecting plain SGD updates. */
    public static final int MODE_SGD = 0;

    /** {@link #setOptimizationMode(int)} value selecting AdaGrad updates. */
    public static final int MODE_ADAGRAD = 1;

    private double w0;
    /** Linear weights; entry {@code j} belongs to attribute key {@code j}. */
    public RealVector w;
    /** Latent factors; row = attribute key, column = factor index. */
    public RealMatrix v;
    private int numAttributes;
    /** L2 regularization strength applied to the bias {@code w0}. */
    private double reg0;
    // AdaGrad accumulators of squared gradients for w0, w and v respectively.
    private double gradientForW0History;
    private double[] gradientForWHistory;
    private double[][] gradientForVHistory;
    private double lastPrediction;
    /** Stored and exposed through accessors only; this class never reads it. */
    private int numEpochs = 100;
    private double maxTarget = Double.MAX_VALUE;
    private double minTarget = -Double.MAX_VALUE;
    private int numFactor = 10;
    private double regW = 0.01d;
    private double regV = 0.01d;
    private boolean useW0 = true;
    private boolean useW = true;
    private boolean useV = true;
    private int optimizationMode = MODE_SGD;
    /**
     * Stored and exposed through accessors only; the AdaGrad updates guard a zero accumulator by
     * falling back to the base learning rate instead of adding this epsilon.
     */
    private double adagradEpsilon = 1.0E-5d;
    private double init_stdev = 0.01d;
    public double init_mean = 0.0d;
    private double learnRate = 0.1d;

    /** Creates an unconfigured model; call {@link #initializeModel(int)} before use. */
    public FactorizationMachineRegression() {
    }

    /**
     * Creates an AdaGrad-trained model whose parameters are drawn from a seeded generator, so two
     * models built with identical arguments start from identical parameters.
     *
     * @param regularization L2 strength used for both {@code regW} and {@code regV}
     * @param learnRate base learning rate
     * @param numFactors number of latent factors per attribute
     * @param numAttributes number of attributes; instance keys must lie in {@code [0, numAttributes)}
     * @param initStdev standard deviation of the Gaussian parameter initialization
     * @param maxTarget upper clamp applied to predictions
     * @param minTarget lower clamp applied to predictions
     * @param seed seed for the initialization {@link Random}
     */
    public FactorizationMachineRegression(double regularization, double learnRate, int numFactors,
            int numAttributes, double initStdev, double maxTarget, double minTarget, long seed) {
        this(regularization, learnRate, numFactors, numAttributes, initStdev, maxTarget, minTarget,
                new Random(seed));
    }

    /**
     * Same as the seeded constructor, but initializes parameters from an unseeded {@link Random}.
     *
     * @param regularization L2 strength used for both {@code regW} and {@code regV}
     * @param learnRate base learning rate
     * @param numFactors number of latent factors per attribute
     * @param numAttributes number of attributes; instance keys must lie in {@code [0, numAttributes)}
     * @param initStdev standard deviation of the Gaussian parameter initialization
     * @param maxTarget upper clamp applied to predictions
     * @param minTarget lower clamp applied to predictions
     */
    public FactorizationMachineRegression(double regularization, double learnRate, int numFactors,
            int numAttributes, double initStdev, double maxTarget, double minTarget) {
        this(regularization, learnRate, numFactors, numAttributes, initStdev, maxTarget, minTarget,
                new Random());
    }

    /** Shared constructor body: configures hyper-parameters, then draws initial parameters. */
    private FactorizationMachineRegression(double regularization, double learnRate, int numFactors,
            int numAttributes, double initStdev, double maxTarget, double minTarget, Random random) {
        setOptimizationMode(MODE_ADAGRAD);
        setLearnRate(learnRate);
        setNumAttributes(numAttributes);
        setW0(0.0d);
        setNumFactor(numFactors);
        setReg0(0.0d);
        setRegV(regularization);
        setRegW(regularization);
        setInit_stdev(initStdev);
        initParameters(random);
        setMaxTarget(maxTarget);
        setMinTarget(minTarget);
    }

    /**
     * Resets the model for {@code numAttributes} attributes, drawing fresh parameters from an
     * unseeded generator. Uses the currently configured factor count and init mean/stdev.
     *
     * @param numAttributes number of attributes; instance keys must lie in {@code [0, numAttributes)}
     */
    public void initializeModel(int numAttributes) {
        initializeModel(numAttributes, new Random());
    }

    /**
     * Like {@link #initializeModel(int)} but with a seeded generator, for reproducible runs.
     *
     * @param numAttributes number of attributes; instance keys must lie in {@code [0, numAttributes)}
     * @param seed seed for the initialization {@link Random}
     */
    public void initializeModel(int numAttributes, long seed) {
        initializeModel(numAttributes, new Random(seed));
    }

    private void initializeModel(int numAttributes, Random random) {
        setNumAttributes(numAttributes);
        setW0(0.0d);
        initParameters(random);
    }

    /**
     * Draws {@code w} and {@code v} from N(init_mean, init_stdev^2) and clears all AdaGrad
     * accumulators. Draw order (all of {@code w}, then {@code v} row by row) is fixed so a given
     * seed always yields the same model.
     */
    private void initParameters(Random random) {
        int attributes = getNumAttributes();
        int factors = getNumFactor();
        double stdev = getInit_stdev();
        double[] linear = new double[attributes];
        double[][] latent = new double[attributes][factors];
        for (int j = 0; j < attributes; j++) {
            linear[j] = this.init_mean + (random.nextGaussian() * stdev);
        }
        for (int j = 0; j < attributes; j++) {
            for (int f = 0; f < factors; f++) {
                latent[j][f] = this.init_mean + (random.nextGaussian() * stdev);
            }
        }
        this.w = MatrixUtils.createRealVector(linear);
        this.v = MatrixUtils.createRealMatrix(latent);
        this.gradientForWHistory = new double[attributes];
        this.gradientForVHistory = new double[attributes][factors];
        this.gradientForW0History = 0.0d;
    }

    /**
     * Predicts every instance in order.
     *
     * @param instances instances to score
     * @return one clamped prediction per instance
     */
    @Override
    public double[] predict(Instances instances) {
        double[] predictions = new double[instances.numInstances()];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(instances.instance(i));
        }
        return predictions;
    }

    /**
     * Computes the factorization-machine output for one sparse instance, clamps it to
     * {@code [minTarget, maxTarget]}, and remembers it as {@link #getLastPrediction()}.
     *
     * @param instance sparse instance; every key must be a valid row of {@link #w}/{@link #v}
     * @return the clamped prediction
     */
    @Override
    public double predict(Instance instance) {
        int[] keys = instance.getKeys();
        double[] values = instance.getValues();
        double prediction = 0.0d;
        if (isUseW0()) {
            prediction += getW0();
        }
        if (isUseW()) {
            for (int i = 0; i < keys.length; i++) {
                prediction += this.w.getEntry(keys[i]) * values[i];
            }
        }
        if (isUseV()) {
            // Pairwise interactions in O(k * n): 0.5 * ((sum v x)^2 - sum (v x)^2) per factor.
            for (int f = 0; f < getNumFactor(); f++) {
                double sum = 0.0d;
                double sumOfSquares = 0.0d;
                for (int i = 0; i < keys.length; i++) {
                    double term = this.v.getEntry(keys[i], f) * values[i];
                    sum += term;
                    sumOfSquares += term * term;
                }
                prediction += 0.5d * ((sum * sum) - sumOfSquares);
            }
        }
        if (prediction > this.maxTarget) {
            prediction = this.maxTarget;
        }
        if (prediction < this.minTarget) {
            prediction = this.minTarget;
        }
        this.lastPrediction = prediction;
        return prediction;
    }

    /**
     * Returns {@code sum_j v[key_j][factor] * x_j} over the instance's non-zero entries.
     *
     * @param instance sparse instance
     * @param factor latent factor (column of {@link #v})
     * @return the weighted sum for that factor
     */
    public double preComputeSum(Instance instance, int factor) {
        int[] keys = instance.getKeys();
        double[] values = instance.getValues();
        double sum = 0.0d;
        for (int i = 0; i < keys.length; i++) {
            sum += this.v.getEntry(keys[i], factor) * values[i];
        }
        return sum;
    }

    /**
     * Applies one update step using the configured optimization mode. An unknown mode prints a
     * message to standard error and terminates the JVM.
     *
     * @param instance the training instance
     * @param d derivative of the loss with respect to the model's prediction
     */
    public void doSGD(Instance instance, double d) {
        switch (this.optimizationMode) {
            case MODE_SGD:
                SGD(instance, d);
                return;
            case MODE_ADAGRAD:
                SGDAdagrad(instance, d);
                return;
            default:
                System.err.println("Optimization Mode not set, exiting...");
                System.exit(1);
        }
    }

    /**
     * Performs a single online pass over {@code instances}, one update per instance.
     *
     * @param instances training data
     */
    @Override
    public void train(Instances instances) {
        for (int i = 0; i < instances.numInstances(); i++) {
            train(instances.instance(i));
        }
    }

    /**
     * Predicts {@code instance} with the current model and updates the model towards its target.
     *
     * @param instance the training instance
     */
    public void train(Instance instance) {
        train(instance, predict(instance));
    }

    /**
     * Updates the model using a squared-error loss, whose derivative with respect to the
     * prediction is {@code 2 * (prediction - target)}.
     *
     * @param instance the training instance
     * @param d the model's (clamped) prediction for {@code instance}
     */
    public void train(Instance instance, double d) {
        doSGD(instance, 2.0d * (d - instance.target()));
    }

    /**
     * AdaGrad update: each parameter's step is {@code learnRate / sqrt(sum of its squared
     * gradients)}, or plain {@code learnRate} while that accumulator is still zero. A non-finite
     * result prints a message and terminates the JVM.
     *
     * @param instance the training instance
     * @param d derivative of the loss with respect to the model's prediction
     */
    public void SGDAdagrad(Instance instance, double d) {
        int[] keys = instance.getKeys();
        double[] values = instance.getValues();
        if (isUseW0()) {
            double gradient = d + getReg0() * getW0();
            this.gradientForW0History += gradient * gradient;
            // Guarded via adagradRate: a zero accumulator previously produced Infinity * 0 = NaN.
            double updated = getW0() - adagradRate(this.gradientForW0History) * gradient;
            exitIfNotFinite(updated, "Update of Bias is about to be NaN...");
            setW0(updated);
        }
        if (isUseW()) {
            for (int i = 0; i < keys.length; i++) {
                int j = keys[i];
                double current = this.w.getEntry(j);
                double gradient = d * values[i] + getRegW() * current;
                this.gradientForWHistory[j] += gradient * gradient;
                double updated = current - adagradRate(this.gradientForWHistory[j]) * gradient;
                exitIfNotFinite(updated, "Update of a Regression Term is about to be NaN...");
                this.w.setEntry(j, updated);
            }
        }
        if (isUseV()) {
            for (int f = 0; f < getNumFactor(); f++) {
                // Computed once per factor, before any v[., f] of this instance is modified.
                double sum = preComputeSum(instance, f);
                for (int i = 0; i < keys.length; i++) {
                    int j = keys[i];
                    double current = this.v.getEntry(j, f);
                    double gradient = latentGradient(d, values[i], current, sum);
                    this.gradientForVHistory[j][f] += gradient * gradient;
                    double updated = current - adagradRate(this.gradientForVHistory[j][f]) * gradient;
                    exitIfNotFinite(updated, "Update of a latent Feature is about to be NaN...");
                    this.v.setEntry(j, f, updated);
                }
            }
        }
    }

    /**
     * Plain SGD update with the fixed {@link #getLearnRate() learning rate}. A non-finite result
     * prints a message and terminates the JVM.
     *
     * @param instance the training instance
     * @param d derivative of the loss with respect to the model's prediction
     */
    public void SGD(Instance instance, double d) {
        int[] keys = instance.getKeys();
        double[] values = instance.getValues();
        if (isUseW0()) {
            double updated = getW0() - getLearnRate() * (d + getReg0() * getW0());
            exitIfNotFinite(updated, "Update of Bias is about to be NaN...");
            setW0(updated);
        }
        if (isUseW()) {
            for (int i = 0; i < keys.length; i++) {
                int j = keys[i];
                double current = this.w.getEntry(j);
                double updated = current - getLearnRate() * (d * values[i] + getRegW() * current);
                exitIfNotFinite(updated, "Update of a Regression Term is about to be NaN...");
                this.w.setEntry(j, updated);
            }
        }
        if (isUseV()) {
            for (int f = 0; f < getNumFactor(); f++) {
                // Computed once per factor, before any v[., f] of this instance is modified.
                double sum = preComputeSum(instance, f);
                for (int i = 0; i < keys.length; i++) {
                    int j = keys[i];
                    double current = this.v.getEntry(j, f);
                    double updated = current - getLearnRate() * latentGradient(d, values[i], current, sum);
                    exitIfNotFinite(updated, "Update of a latent Feature is about to be NaN...");
                    this.v.setEntry(j, f, updated);
                }
            }
        }
    }

    /**
     * Gradient of the loss with respect to one latent entry {@code v_jf}, including its L2 term:
     * {@code d * (x_j * sum_k v_kf x_k - v_jf * x_j^2) + regV * v_jf}.
     */
    private double latentGradient(double lossGradient, double x, double vjf, double factorSum) {
        return lossGradient * (x * factorSum - vjf * x * x) + getRegV() * vjf;
    }

    /** AdaGrad step size for an accumulator; the base rate while the accumulator is zero. */
    private double adagradRate(double squaredGradientSum) {
        return squaredGradientSum == 0.0d
                ? getLearnRate()
                : getLearnRate() / Math.sqrt(squaredGradientSum);
    }

    /** Fail-fast guard kept from the original design: prints {@code message} and exits the JVM. */
    private static void exitIfNotFinite(double value, String message) {
        if (Double.isNaN(value) || Double.isInfinite(value)) {
            System.out.println(message);
            System.exit(1);
        }
    }

    /**
     * Returns the prediction together with an uncertainty estimate. This model has none, so the
     * second entry is always 0.
     *
     * @param instance instance to score
     * @return {@code {prediction, 0.0}}
     */
    public double[] predictWithUncertainty(Instance instance) {
        return new double[]{predict(instance), 0.0d};
    }

    /** @return the global bias term */
    public double getW0() {
        return this.w0;
    }

    /** @param w0 the global bias term */
    public void setW0(double w0) {
        this.w0 = w0;
    }

    /** @return the number of attributes the model was sized for */
    public int getNumAttributes() {
        return this.numAttributes;
    }

    /** @param numAttributes number of attributes; takes effect at the next (re)initialization */
    public void setNumAttributes(int numAttributes) {
        this.numAttributes = numAttributes;
    }

    /** @return the number of latent factors per attribute */
    public int getNumFactor() {
        return this.numFactor;
    }

    /** @param numFactor latent factors per attribute; must match {@link #v} once initialized */
    public void setNumFactor(int numFactor) {
        this.numFactor = numFactor;
    }

    /** @return L2 regularization strength for the bias */
    public double getReg0() {
        return this.reg0;
    }

    /** @param reg0 L2 regularization strength for the bias */
    public void setReg0(double reg0) {
        this.reg0 = reg0;
    }

    /** @return L2 regularization strength for the linear weights */
    public double getRegW() {
        return this.regW;
    }

    /** @param regW L2 regularization strength for the linear weights */
    public void setRegW(double regW) {
        this.regW = regW;
    }

    /** @return L2 regularization strength for the latent factors */
    public double getRegV() {
        return this.regV;
    }

    /** @param regV L2 regularization strength for the latent factors */
    public void setRegV(double regV) {
        this.regV = regV;
    }

    /** @return whether the bias term is used in prediction and training */
    public boolean isUseW0() {
        return this.useW0;
    }

    /** @param useW0 whether the bias term is used in prediction and training */
    public void setUseW0(boolean useW0) {
        this.useW0 = useW0;
    }

    /** @return whether the linear term is used in prediction and training */
    public boolean isUseW() {
        return this.useW;
    }

    /** @param useW whether the linear term is used in prediction and training */
    public void setUseW(boolean useW) {
        this.useW = useW;
    }

    /** @return whether the pairwise interaction term is used in prediction and training */
    public boolean isUseV() {
        return this.useV;
    }

    /** @param useV whether the pairwise interaction term is used in prediction and training */
    public void setUseV(boolean useV) {
        this.useV = useV;
    }

    /** @return standard deviation used when drawing initial parameters */
    public double getInit_stdev() {
        return this.init_stdev;
    }

    /** @param initStdev standard deviation used when drawing initial parameters */
    public void setInit_stdev(double initStdev) {
        this.init_stdev = initStdev;
    }

    /** @return the base learning rate */
    public double getLearnRate() {
        return this.learnRate;
    }

    /** @param learnRate the base learning rate */
    public void setLearnRate(double learnRate) {
        this.learnRate = learnRate;
    }

    /** @return the stored epoch count (not read by this class) */
    public int getNumEpochs() {
        return this.numEpochs;
    }

    /** @param numEpochs epoch count to store (not read by this class) */
    public void setNumEpochs(int numEpochs) {
        this.numEpochs = numEpochs;
    }

    /** @return upper clamp applied to predictions */
    public double getMaxTarget() {
        return this.maxTarget;
    }

    /** @param maxTarget upper clamp applied to predictions */
    public void setMaxTarget(double maxTarget) {
        this.maxTarget = maxTarget;
    }

    /** @return lower clamp applied to predictions */
    public double getMinTarget() {
        return this.minTarget;
    }

    /** @param minTarget lower clamp applied to predictions */
    public void setMinTarget(double minTarget) {
        this.minTarget = minTarget;
    }

    /** @return the optimization mode, {@link #MODE_SGD} or {@link #MODE_ADAGRAD} */
    public int getOptimizationMode() {
        return this.optimizationMode;
    }

    /**
     * @param optimizationMode {@link #MODE_SGD} or {@link #MODE_ADAGRAD}; any other value makes
     *     {@link #doSGD} print an error and exit the JVM
     */
    public void setOptimizationMode(int optimizationMode) {
        this.optimizationMode = optimizationMode;
    }

    /** @return the value returned by the most recent {@link #predict(Instance)} call */
    public double getLastPrediction() {
        return this.lastPrediction;
    }

    /** @return the stored AdaGrad epsilon (not read by the update rules) */
    public double getAdagradEpsilon() {
        return this.adagradEpsilon;
    }

    /** @param adagradEpsilon AdaGrad epsilon to store (not read by the update rules) */
    public void setAdagradEpsilon(double adagradEpsilon) {
        this.adagradEpsilon = adagradEpsilon;
    }
}
