package org.neuroph.nnet.flat;

import java.io.Serializable;
import org.encog.engine.EncogEngineError;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.network.train.prop.TrainFlatNetworkBackPropagation;
import org.encog.engine.network.train.prop.TrainFlatNetworkManhattan;
import org.encog.engine.network.train.prop.TrainFlatNetworkResilient;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.core.learning.TrainingSet;

/* loaded from: classes.dex */
/**
 * Neuroph supervised learning rule that delegates the actual training to an Encog
 * {@link FlatNetwork} trainer ({@link TrainFlatNetwork}).
 *
 * <p>Each call to {@link #doLearningEpoch(TrainingSet)} runs one Encog training iteration.
 * The Encog trainer is (re)built lazily whenever the learning type or the training set changes,
 * or when no trainer exists yet (it is {@code transient}, so it is also rebuilt after
 * deserialization).
 */
public class FlatNetworkLearning extends SupervisedLearning implements Serializable {
    private static final long serialVersionUID = 1;

    /** The flat network being trained. */
    private FlatNetwork flat;
    /** Learning type the current {@link #training} was built with; used to detect changes. */
    private FlatLearningType lastLearningType;
    /** Training set the current {@link #training} was built with; not serialized. */
    private transient EngineIndexableSet lastTrainingSet;
    /** Learning rate passed to Manhattan and back-propagation trainers. */
    private double learningRate;
    /** Training algorithm to use for the next (re)built trainer. */
    private FlatLearningType learningType;
    /** Momentum passed to the back-propagation trainer. */
    private double momentum;
    /** Thread count handed to the Encog trainer when it is built. */
    private int numThreads;
    /** Current Encog trainer; not serialized and rebuilt lazily when needed. */
    private transient TrainFlatNetwork training;

    /**
     * Creates a learning rule that trains the given flat network directly.
     *
     * @param flatNetwork the Encog flat network to train
     */
    public FlatNetworkLearning(FlatNetwork flatNetwork) {
        init(flatNetwork);
    }

    /**
     * Creates a learning rule for a Neuroph network that has a {@link FlatNetworkPlugin} attached.
     *
     * @param neuralNetwork the Neuroph network whose flat representation will be trained
     * @throws EncogEngineError if no {@link FlatNetworkPlugin} (or no flat network) is attached
     */
    public FlatNetworkLearning(NeuralNetwork neuralNetwork) {
        FlatNetworkPlugin plugin = (FlatNetworkPlugin) neuralNetwork.getPlugin(FlatNetworkPlugin.class);
        // Check the local value: the field is not assigned until init() runs, so the original
        // check on this.flat always failed and made this constructor unusable.
        FlatNetwork flatNetwork = (plugin == null) ? null : plugin.getFlatNetwork();
        if (flatNetwork == null) {
            throw new EncogEngineError("This learning rule only works with a network that has a FlatNetworkPlugin attached.");
        }
        init(flatNetwork);
    }

    /** Stores the network and applies the default training settings (RPROP, rate 0.7, momentum 0.3). */
    private void init(FlatNetwork flatNetwork) {
        this.flat = flatNetwork;
        this.learningType = FlatLearningType.ResilientPropagation;
        this.learningRate = 0.7d;
        this.momentum = 0.3d;
    }

    /**
     * Not used by this rule: the error is computed by the Encog trainer.
     *
     * @throws EncogEngineError always
     */
    @Override
    public void addToSquaredErrorSum(double[] dArr) {
        throw new EncogEngineError("Not supported, and should not be called.");
    }

    /**
     * Runs one Encog training iteration over the given training set.
     *
     * <p>The trainer is rebuilt when none exists yet, when the learning type changed, or when a
     * different training set instance is supplied. After the iteration, the network error is
     * taken from the trainer and learning is stopped if the stop condition has been reached.
     *
     * @param trainingSet the training data for this epoch
     * @throws EncogEngineError if the configured learning type is null or unsupported
     */
    @Override
    public void doLearningEpoch(TrainingSet trainingSet) {
        this.previousEpochError = this.totalNetworkError;
        if (this.training == null
                || this.lastLearningType != this.learningType
                || this.lastTrainingSet != trainingSet) {
            this.training = createTraining(trainingSet);
            this.training.setNumThreads(this.numThreads);
            this.lastLearningType = this.learningType;
            this.lastTrainingSet = trainingSet;
        }
        this.training.iteration();
        this.totalNetworkError = this.training.getError();
        if (hasReachedStopCondition()) {
            stopLearning();
        }
    }

    /** Builds the Encog trainer that matches the current {@link #learningType}. */
    private TrainFlatNetwork createTraining(EngineIndexableSet set) {
        if (this.learningType == null) {
            throw new EncogEngineError("No learning type has been set.");
        }
        switch (this.learningType) {
            case ResilientPropagation:
                return new TrainFlatNetworkResilient(this.flat, set);
            case ManhattanUpdateRule:
                return new TrainFlatNetworkManhattan(this.flat, set, this.learningRate);
            case BackPropagation:
                return new TrainFlatNetworkBackPropagation(this.flat, set, getLearningRate(), getMomentum());
            default:
                // Without this, an unhandled type would leave the trainer null and fail with an NPE.
                throw new EncogEngineError("Unsupported learning type: " + this.learningType);
        }
    }

    /** @return the learning rate used by the Manhattan and back-propagation trainers */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return the training algorithm used when the trainer is next built */
    public FlatLearningType getLearningType() {
        return this.learningType;
    }

    /** @return the momentum used by the back-propagation trainer */
    public double getMomentum() {
        return this.momentum;
    }

    /** @return the thread count handed to the Encog trainer when it is built */
    public int getNumThreads() {
        return this.numThreads;
    }

    /**
     * Sets the learning rate.
     *
     * <p>NOTE: the value is only read when the trainer is (re)built, which happens on a change of
     * learning type or training set.
     *
     * @param d the new learning rate
     */
    @Override
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /**
     * Sets the training algorithm. A change causes the trainer to be rebuilt on the next epoch.
     *
     * @param flatLearningType the algorithm to use
     */
    public void setLearningType(FlatLearningType flatLearningType) {
        this.learningType = flatLearningType;
    }

    /**
     * Sets the momentum used by back-propagation.
     *
     * <p>NOTE: the value is only read when the trainer is (re)built.
     *
     * @param d the new momentum
     */
    public void setMomentum(double d) {
        this.momentum = d;
    }

    /**
     * Sets the thread count for the Encog trainer.
     *
     * <p>NOTE: the value is only applied when the trainer is (re)built.
     *
     * @param i the number of threads
     */
    public void setNumThreads(int i) {
        this.numThreads = i;
    }

    /**
     * Not used by this rule: weights are updated by the Encog trainer.
     *
     * @throws EncogEngineError always
     */
    @Override
    protected void updateNetworkWeights(double[] dArr) {
        throw new EncogEngineError("Method (updateNetworkWeights) is unimplemented and should not have been called.");
    }
}
