package org.encog.neural.cpn.training;

import org.encog.mathutil.BoundMath;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.cpn.CPN;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;

/* loaded from: input_file:org/encog/neural/cpn/training/TrainInstar.class */
/**
 * Iterative trainer for the input-to-instar weights of a {@link CPN}.
 *
 * <p>Each call to {@link #iteration()} walks the training set once. For every
 * training pair the instar neuron with the largest output is selected as the
 * winner, and that neuron's incoming weights are moved towards the pair's
 * input by a fraction equal to the learning rate. The error reported after a
 * pass is the largest input-to-winner distance seen during that pass.
 *
 * <p>Optionally, before the first pass, the weights can be seeded directly
 * from the training data (one training element per instar neuron).
 */
public class TrainInstar extends BasicTraining implements LearningRate {

    /** Network whose input-to-instar weight matrix is modified by training. */
    private final CPN network;

    /** Training data; only the input half of each pair is read. */
    private final MLDataSet training;

    /** Fraction of the input/weight difference applied to the winner per pair. */
    private double learningRate;

    /** True while the weights still have to be seeded from the training data. */
    private boolean mustInit;

    /**
     * Creates the trainer.
     *
     * @param cpn       the network to train
     * @param mLDataSet the training data
     * @param d         the learning rate
     * @param z         whether the weights should be seeded from the training
     *                  data at the start of the first iteration
     */
    public TrainInstar(CPN cpn, MLDataSet mLDataSet, double d, boolean z) {
        super(TrainingImplementationType.Iterative);
        network = cpn;
        training = mLDataSet;
        learningRate = d;
        mustInit = z;
    }

    /**
     * Runs one pass over the training set, updating the winning neuron's
     * weights for every pair and recording the worst distance as the error.
     *
     * <p>If the training set yields no pairs, the recorded error is
     * {@link Double#NEGATIVE_INFINITY}.
     *
     * @throws NeuralNetworkError if weight seeding was requested and the number
     *         of training records differs from the number of instar neurons
     */
    @Override
    public void iteration() {
        if (mustInit) {
            initWeights();
        }

        double worstDistance = Double.NEGATIVE_INFINITY;
        for (final MLDataPair pair : training) {
            // The winner is the instar neuron with the largest output.
            final int winner = EngineArray.indexOfLargest(
                    network.computeInstar(pair.getInput()).getData());

            // Measure before updating, so the distance reflects the old weights.
            final double distance = distanceToWinner(pair, winner);
            worstDistance = (distance > worstDistance) ? distance : worstDistance;

            pullWinnerTowards(pair, winner);
        }
        setError(worstDistance);
    }

    /**
     * Seeds the input-to-instar weights from the training data: the n-th
     * training pair's input becomes the weight column of the n-th instar
     * neuron. Clears the pending-initialisation flag when done.
     *
     * @throws NeuralNetworkError if the number of training records differs
     *         from the number of instar neurons
     */
    private void initWeights() {
        final boolean oneNeuronPerRecord =
                training.getRecordCount() == network.getInstarCount();
        if (!oneNeuronPerRecord) {
            throw new NeuralNetworkError("If the weights are to be set from the training data, then there must be one instar neuron for each training element.");
        }

        int neuron = 0;
        for (final MLDataPair pair : training) {
            seedNeuron(pair, neuron);
            neuron++;
        }
        mustInit = false;
    }

    /** Copies the pair's input values into the weights feeding the given instar neuron. */
    private void seedNeuron(final MLDataPair pair, final int neuron) {
        for (int input = 0; input < network.getInputCount(); input++) {
            final double value = pair.getInput().getData(input);
            network.getWeightsInputToInstar().set(input, neuron, value);
        }
    }

    /**
     * Returns the square root (via {@link BoundMath#sqrt}) of the summed
     * squared differences between the pair's input and the winner's weights.
     */
    private double distanceToWinner(final MLDataPair pair, final int winner) {
        double sumOfSquares = 0.0d;
        for (int input = 0; input < pair.getInput().size(); input++) {
            final double diff = pair.getInput().getData(input)
                    - network.getWeightsInputToInstar().get(input, winner);
            sumOfSquares += diff * diff;
        }
        return BoundMath.sqrt(sumOfSquares);
    }

    /** Moves each of the winner's weights towards the pair's input, scaled by the learning rate. */
    private void pullWinnerTowards(final MLDataPair pair, final int winner) {
        for (int input = 0; input < network.getInputCount(); input++) {
            final double gap = pair.getInput().getData(input)
                    - network.getWeightsInputToInstar().get(input, winner);
            final double delta = learningRate * gap;
            network.getWeightsInputToInstar().add(input, winner, delta);
        }
    }

    /**
     * @return the network being trained
     */
    @Override
    public CPN getMethod() {
        return network;
    }

    /**
     * @return the current learning rate
     */
    @Override
    public double getLearningRate() {
        return learningRate;
    }

    /**
     * Replaces the learning rate used by subsequent iterations.
     *
     * @param d the new learning rate
     */
    @Override
    public void setLearningRate(double d) {
        learningRate = d;
    }

    /**
     * @return always {@code false}; this trainer does not offer pause/resume
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * @return always {@code null}; no continuation state is produced
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Does nothing; the supplied continuation is ignored.
     *
     * @param trainingContinuation ignored
     */
    @Override
    public void resume(TrainingContinuation trainingContinuation) {
        // Intentionally empty: there is no state to restore.
    }
}
