package org.encog.mathutil.matrices.hessian;

import org.encog.mathutil.matrices.Matrix;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.encog.util.EngineArray;

/* loaded from: input_file:org/encog/mathutil/matrices/hessian/BasicHessian.class */
/**
 * Base implementation of {@link ComputeHessian} that owns the gradient vector and the
 * Hessian storage, leaving the actual gradient/Hessian computation to subclasses.
 *
 * <p>{@link #init} sizes both structures to the number of weights in the network's flat
 * representation; {@link #updateHessian} accumulates the outer product of a derivative
 * vector into the Hessian.
 */
public abstract class BasicHessian implements ComputeHessian {
    /** The training data supplied to {@link #init}. */
    protected MLDataSet training;

    /** The network supplied to {@link #init}. */
    protected BasicNetwork network;

    /**
     * Sum of squared errors reported by {@link #getSSE()}. Never assigned in this class;
     * presumably populated by subclasses while computing the Hessian.
     */
    protected double sse;

    /**
     * Gradient vector with one entry per network weight. Allocated in {@link #init} and
     * zeroed by {@link #clear()}; presumably filled in by subclasses.
     */
    protected double[] gradients;

    /** The Hessian as a {@link Matrix}, sized (weights x weights) in {@link #init}. */
    protected Matrix hessianMatrix;

    /**
     * Raw 2-D array obtained from {@code hessianMatrix.getData()}.
     *
     * <p>NOTE(review): this class writes through this array ({@link #updateHessian}) but
     * clears and exposes {@link #hessianMatrix}, so it assumes {@code getData()} returns
     * the matrix's live backing store and that {@code Matrix.clear()} zeroes it in place
     * — confirm against {@code Matrix}.
     */
    protected double[][] hessian;

    /** Flat representation of {@link #network}, cached in {@link #init}. */
    protected FlatNetwork flat;

    /**
     * Binds this calculator to a network and training set and allocates storage for the
     * gradients and the Hessian, both sized to the network's weight count.
     *
     * @param theNetwork  the network whose flat weight count defines the problem dimension
     * @param theTraining the training data to associate with this calculator
     */
    @Override
    public void init(final BasicNetwork theNetwork, final MLDataSet theTraining) {
        // Fetch the flat network once and derive the dimension from it, rather than going
        // through two different accessor paths for the same object.
        this.flat = theNetwork.getFlat();
        final int weightCount = this.flat.getWeights().length;

        this.training = theTraining;
        this.network = theNetwork;
        this.gradients = new double[weightCount];
        this.hessianMatrix = new Matrix(weightCount, weightCount);
        this.hessian = this.hessianMatrix.getData();
    }

    /**
     * Returns the gradient vector.
     *
     * @return the internal gradient array itself (not a copy)
     */
    @Override
    public double[] getGradients() {
        return this.gradients;
    }

    /**
     * Returns the Hessian wrapped as a {@link Matrix}.
     *
     * @return the internal Hessian matrix (not a copy)
     */
    @Override
    public Matrix getHessianMatrix() {
        return this.hessianMatrix;
    }

    /**
     * Returns the Hessian as a raw 2-D array.
     *
     * @return the array obtained from {@code hessianMatrix.getData()} during
     *     {@link #init} (not a copy)
     */
    @Override
    public double[][] getHessian() {
        return this.hessian;
    }

    /**
     * Resets the gradients to zero and clears the Hessian matrix.
     */
    @Override
    public void clear() {
        EngineArray.fill(this.gradients, 0.0d);
        this.hessianMatrix.clear();
    }

    /**
     * Returns the sum of squared errors held in {@link #sse}.
     *
     * @return the stored SSE value
     */
    @Override
    public double getSSE() {
        return this.sse;
    }

    /**
     * Adds the outer product of {@code derivative} with itself to the Hessian, i.e.
     * {@code hessian[i][j] += derivative[i] * derivative[j]} for every pair of weight
     * indices. Presumably this accumulates a Gauss-Newton style approximation one
     * training element at a time; confirm against the calling subclasses.
     *
     * @param derivative per-weight derivative values; must have at least as many entries
     *     as the network has weights
     * @throws ArrayIndexOutOfBoundsException if {@code derivative} is shorter than the
     *     network's weight count
     */
    public void updateHessian(final double[] derivative) {
        final int weightCount = this.network.getFlat().getWeights().length;
        for (int i = 0; i < weightCount; i++) {
            // Both the row reference and the scaling factor are invariant across the
            // inner loop, so look them up once per row.
            final double[] row = this.hessian[i];
            final double scale = derivative[i];
            for (int j = 0; j < weightCount; j++) {
                row[j] += scale * derivative[j];
            }
        }
    }
}
