package org.neuroph.core.learning;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Iterator;
import org.encog.engine.network.flat.FlatNetwork;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;

/* loaded from: classes.dex */
/**
 * Base class for supervised learning rules.
 *
 * <p>For every training element the network is run on the element's input, and the output
 * error ({@code desired - actual}) is computed. Half of that error's squared norm is added to the
 * epoch's squared-error sum, and {@link #updateNetworkWeights(double[])} is invoked. After each
 * epoch the mean error over the presented patterns is stored in {@link #totalNetworkError}.
 * Learning stops once that error drops below {@link #maxError}, or once the epoch-to-epoch error
 * change has stayed within {@link #setMinErrorChange(double)} for
 * {@link #setMinErrorChangeIterationsLimit(int)} consecutive epochs.
 *
 * <p>In batch mode, the weight changes accumulated in {@link Weight#weightChange} are applied
 * once at the end of each epoch by {@link #doBatchWeightsUpdate()}.
 */
public abstract class SupervisedLearning extends IterativeLearning implements Serializable {
    private static final long serialVersionUID = 3L;

    /** Number of consecutive epochs whose error change was within {@code minErrorChange}. */
    private transient int minErrorChangeIterationsCount;

    /** Error of the previous epoch, captured in {@link #beforeEpochStart()}. */
    protected transient double previousEpochError;

    /** Mean error of the most recent epoch, computed in {@link #doLearningEpoch(TrainingSet)}. */
    protected transient double totalNetworkError;

    /** Running sum of {@code 0.5 * e^2} over every output error presented this epoch. */
    protected transient double totalSquaredErrorSum;

    /** Learning stops once the epoch error falls strictly below this value. */
    protected double maxError = 0.01d;

    /**
     * Largest absolute epoch-to-epoch error change still counted as "stalled". Transient, so
     * {@link #readObject(ObjectInputStream)} restores this default after deserialization.
     */
    private transient double minErrorChange = Double.POSITIVE_INFINITY;

    /**
     * Number of consecutive stalled epochs after which learning stops. Transient, so
     * {@link #readObject(ObjectInputStream)} restores this default after deserialization.
     */
    private transient int minErrorChangeIterationsLimit = Integer.MAX_VALUE;

    /** When {@code true}, weights are updated once per epoch instead of once per pattern. */
    private boolean batchMode = false;

    /**
     * Restores the defaults of the transient stall-detection settings after deserialization.
     *
     * <p>Java serialization does not run field initializers, so without this method the transient
     * fields would come back as {@code 0.0} and {@code 0}. A limit of {@code 0} makes
     * {@link #errorChangeStalled()} report a stall on the first epoch whose error is unchanged.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.minErrorChange = Double.POSITIVE_INFINITY;
        this.minErrorChangeIterationsLimit = Integer.MAX_VALUE;
    }

    /**
     * Adds {@code 0.5 * sum(e_i^2)} of the given output error vector to the epoch's
     * squared-error sum.
     *
     * <p>Public in this build; the decompiler recorded it as originally {@code protected}.
     *
     * @param outputError per-output error values, as produced by {@link #calculateOutputError}
     */
    public void addToSquaredErrorSum(double[] outputError) {
        double patternError = 0.0;
        for (double error : outputError) {
            patternError += error * error * 0.5d;
        }
        this.totalSquaredErrorSum += patternError;
    }

    /** In batch mode, applies the weight changes accumulated during the epoch. */
    @Override
    protected void afterEpochEnd() {
        if (this.batchMode) {
            doBatchWeightsUpdate();
        }
    }

    /** Remembers the last epoch's error and resets the per-epoch accumulators. */
    @Override
    protected void beforeEpochStart() {
        this.previousEpochError = this.totalNetworkError;
        this.totalNetworkError = 0.0;
        this.totalSquaredErrorSum = 0.0;
    }

    /**
     * Computes the element-wise output error {@code desired[i] - actual[i]}.
     *
     * <p>Public in this build; the decompiler recorded it as originally {@code protected}.
     *
     * @param desiredOutput expected network output
     * @param actualOutput output actually produced by the network
     * @return a new array of the same length holding {@code desired - actual}
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public double[] calculateOutputError(double[] desiredOutput, double[] actualOutput) {
        // Fail fast on mismatched training data instead of throwing an index error
        // (desired shorter) or silently ignoring extra desired values (desired longer).
        if (desiredOutput.length != actualOutput.length) {
            throw new IllegalArgumentException("Desired output length (" + desiredOutput.length
                    + ") does not match network output length (" + actualOutput.length + ")");
        }
        double[] outputError = new double[actualOutput.length];
        for (int i = 0; i < actualOutput.length; i++) {
            outputError[i] = desiredOutput[i] - actualOutput[i];
        }
        return outputError;
    }

    /**
     * Applies and clears the accumulated {@code weightChange} of every input connection, for all
     * layers except layer 0.
     */
    protected void doBatchWeightsUpdate() {
        // Layer 0 is skipped; presumably it is the input layer with no trainable input
        // connections. TODO confirm.
        for (int layerIdx = this.neuralNetwork.getLayersCount() - 1; layerIdx > 0; layerIdx--) {
            Layer layer = (Layer) this.neuralNetwork.getLayers().get(layerIdx);
            Iterator<?> neurons = layer.getNeurons().iterator();
            while (neurons.hasNext()) {
                Neuron neuron = (Neuron) neurons.next();
                Iterator<?> connections = neuron.getInputConnections().iterator();
                while (connections.hasNext()) {
                    Weight weight = ((Connection) connections.next()).getWeight();
                    weight.value += weight.weightChange;
                    weight.weightChange = 0.0;
                }
            }
        }
    }

    /**
     * Presents each training element to the network until the set is exhausted or learning is
     * stopped. It then records the mean error and stops learning if a stop condition is met.
     *
     * <p>The mean is taken over the patterns actually presented. An epoch that presents no
     * patterns (for example, an empty training set) reports an error of {@code 0.0}.
     *
     * @param trainingSet training elements; each is cast to {@link SupervisedTrainingElement}
     */
    @Override
    public void doLearningEpoch(TrainingSet trainingSet) {
        int presentedPatterns = 0;
        Iterator<?> elements = trainingSet.iterator();
        while (elements.hasNext() && !isStopped()) {
            learnPattern((SupervisedTrainingElement) elements.next());
            presentedPatterns++;
        }
        // Avoid 0/0 = NaN: NaN fails both stop criteria, so learning could never stop on error.
        this.totalNetworkError = presentedPatterns > 0
                ? this.totalSquaredErrorSum / presentedPatterns
                : 0.0;
        if (hasReachedStopCondition()) {
            stopLearning();
        }
    }

    /**
     * Tracks consecutive epochs whose error changed by at most {@code minErrorChange}.
     *
     * @return {@code true} once that count reaches {@code minErrorChangeIterationsLimit}
     */
    protected boolean errorChangeStalled() {
        // Written as "<=" so that a NaN change resets the counter.
        if (Math.abs(this.previousEpochError - this.totalNetworkError) <= this.minErrorChange) {
            this.minErrorChangeIterationsCount++;
            return this.minErrorChangeIterationsCount >= this.minErrorChangeIterationsLimit;
        }
        this.minErrorChangeIterationsCount = 0;
        return false;
    }

    /** @return the error threshold below which learning stops */
    public double getMaxError() {
        return this.maxError;
    }

    /** @return the largest error change still counted as stalled */
    public double getMinErrorChange() {
        return this.minErrorChange;
    }

    /** @return the current number of consecutive stalled epochs */
    public int getMinErrorChangeIterationsCount() {
        return this.minErrorChangeIterationsCount;
    }

    /** @return the number of consecutive stalled epochs that stops learning */
    public int getMinErrorChangeIterationsLimit() {
        return this.minErrorChangeIterationsLimit;
    }

    /** @return the error recorded for the epoch before the current one */
    public double getPreviousEpochError() {
        return this.previousEpochError;
    }

    /**
     * Returns the mean error of the most recent epoch.
     *
     * <p>NOTE(review): this getter is synchronized, but the field is written without
     * synchronization elsewhere in this class.
     */
    public synchronized double getTotalNetworkError() {
        return this.totalNetworkError;
    }

    /**
     * Checks whether learning should stop. The stall check is only evaluated (and its counter
     * only advanced) when the error is not yet below {@code maxError}.
     *
     * <p>Public in this build; the decompiler recorded it as originally {@code protected}.
     */
    public boolean hasReachedStopCondition() {
        return this.totalNetworkError < this.maxError || errorChangeStalled();
    }

    /** @return {@code true} if weights are updated once per epoch */
    public boolean isInBatchMode() {
        return this.batchMode;
    }

    /**
     * Sets the maximum error and trains on the given set.
     *
     * @param trainingSet training data
     * @param maxError error threshold below which learning stops
     */
    public void learn(TrainingSet trainingSet, double maxError) {
        this.maxError = maxError;
        learn(trainingSet);
    }

    /**
     * Sets the maximum error and iteration limit, then trains on the given set.
     *
     * @param trainingSet training data
     * @param maxError error threshold below which learning stops
     * @param maxIterations iteration limit passed to {@code setMaxIterations}
     */
    public void learn(TrainingSet trainingSet, double maxError, int maxIterations) {
        this.maxError = maxError;
        setMaxIterations(maxIterations);
        learn(trainingSet);
    }

    /**
     * Runs the network on one element, accumulates its squared error and updates the weights.
     *
     * @param trainingElement input together with its desired output
     */
    protected void learnPattern(SupervisedTrainingElement trainingElement) {
        this.neuralNetwork.setInput(trainingElement.getInput());
        this.neuralNetwork.calculate();
        double[] outputError = calculateOutputError(
                trainingElement.getDesiredOutput(), this.neuralNetwork.getOutput());
        addToSquaredErrorSum(outputError);
        updateNetworkWeights(outputError);
    }

    /**
     * Resets the stall counter and error history at the start of training.
     *
     * <p>Public in this build; the decompiler recorded it as originally {@code protected}.
     */
    @Override
    public void onStart() {
        super.onStart();
        this.minErrorChangeIterationsCount = 0;
        this.totalNetworkError = 0.0;
        this.previousEpochError = 0.0;
    }

    /** @param batchMode {@code true} to apply weight changes once per epoch */
    public void setBatchMode(boolean batchMode) {
        this.batchMode = batchMode;
    }

    /** @param maxError error threshold below which learning stops */
    public void setMaxError(double maxError) {
        this.maxError = maxError;
    }

    /** @param minErrorChange largest error change still counted as stalled */
    public void setMinErrorChange(double minErrorChange) {
        this.minErrorChange = minErrorChange;
    }

    /** @param limit number of consecutive stalled epochs that stops learning */
    public void setMinErrorChangeIterationsLimit(int limit) {
        this.minErrorChangeIterationsLimit = limit;
    }

    /**
     * Adjusts the network weights for the current pattern.
     *
     * @param outputError output error of the pattern just presented ({@code desired - actual})
     */
    protected abstract void updateNetworkWeights(double[] outputError);
}
