package org.neuroph.nnet.learning;

import java.util.Iterator;
import org.encog.engine.network.flat.FlatNetwork;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;
import org.neuroph.util.NeuralNetworkCODEC;

/* loaded from: classes.dex */
/**
 * Supervised learning rule that trains a network by simulated annealing.
 *
 * <p>Each epoch runs {@code cycles} iterations. In every iteration each weight is perturbed by
 * {@code (0.5 - Math.random()) / startTemperature * temperature}, i.e. uniform noise in
 * {@code (-0.5, 0.5]} scaled by {@code temperature / startTemperature}. The perturbed weights are
 * kept only if they lower the training-set error; otherwise the best weights found so far are
 * restored. Within an epoch the temperature is multiplied by a constant factor after every cycle
 * so that it moves geometrically from {@code startTemperature} towards {@code stopTemperature}.
 */
public class SimulatedAnnealingLearning extends SupervisedLearning {
    private static final long serialVersionUID = 1;
    /** Best weight vector found so far, in NeuralNetworkCODEC array layout. */
    private double[] bestWeights;
    /** Number of annealing iterations per learning epoch. */
    private int cycles;
    /** The network whose weights are annealed and evaluated. */
    protected NeuralNetwork network;
    private double startTemperature;
    private double stopTemperature;
    /** Current temperature; reset to startTemperature at the beginning of every epoch. */
    protected double temperature;
    /** Current (possibly perturbed) weight vector, in NeuralNetworkCODEC array layout. */
    private double[] weights;

    /**
     * Creates a learning rule with start temperature 10, stop temperature 2 and 1000 cycles per
     * epoch.
     *
     * @param neuralNetwork network to train
     */
    public SimulatedAnnealingLearning(NeuralNetwork neuralNetwork) {
        this(neuralNetwork, 10.0d, 2.0d, 1000);
    }

    /**
     * Creates a learning rule with an explicit annealing schedule.
     *
     * @param neuralNetwork network to train; its current weights become the initial solution
     * @param startTemp temperature at the start of every epoch (also the perturbation scale
     *     divisor, so it should be non-zero)
     * @param stopTemp temperature targeted at the last cycle of an epoch
     * @param cycles number of annealing iterations per epoch
     */
    public SimulatedAnnealingLearning(NeuralNetwork neuralNetwork, double startTemp, double stopTemp, int cycles) {
        this.network = neuralNetwork;
        this.temperature = startTemp;
        this.startTemperature = startTemp;
        this.stopTemperature = stopTemp;
        this.cycles = cycles;
        int weightCount = NeuralNetworkCODEC.determineArraySize(neuralNetwork);
        this.weights = new double[weightCount];
        this.bestWeights = new double[weightCount];
        NeuralNetworkCODEC.network2array(neuralNetwork, this.weights);
        NeuralNetworkCODEC.network2array(neuralNetwork, this.bestWeights);
    }

    /**
     * Evaluates {@link #network} on the training set.
     *
     * <p>This is a pure measurement: it does not touch {@code totalNetworkError}, because it runs
     * once per annealing cycle and accumulating there would mix many candidate solutions into one
     * number.
     *
     * @param trainingSet elements are expected to be {@link SupervisedTrainingElement}s
     * @return sum over elements of (sum of squared output errors) / (2 * number of outputs); may be
     *     partial if learning is stopped while iterating
     */
    private double determineError(TrainingSet trainingSet) {
        Iterator<?> elements = trainingSet.iterator();
        double error = 0.0d;
        while (elements.hasNext() && !isStopped()) {
            SupervisedTrainingElement element = (SupervisedTrainingElement) elements.next();
            // Evaluate the network this rule actually perturbs (this.network). The inherited
            // neuralNetwork field is not assigned by this class's constructor.
            this.network.setInput(element.getInput());
            this.network.calculate();
            double[] outputError = calculateOutputError(element.getDesiredOutput(), this.network.getOutput());
            double squaredSum = 0.0d;
            for (double e : outputError) {
                squaredSum += e * e;
            }
            error += squaredSum / (outputError.length * 2);
        }
        return error;
    }

    /**
     * Not used by this learning rule.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.neuroph.core.learning.SupervisedLearning
    public void addToSquaredErrorSum(double[] dArr) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Runs one epoch of {@code cycles} annealing iterations, leaving the best weights found in
     * {@link #network} and the corresponding error in {@code totalNetworkError}.
     *
     * @param trainingSet training data used to score every candidate weight vector
     */
    @Override // org.neuroph.core.learning.SupervisedLearning, org.neuroph.core.learning.IterativeLearning
    public void doLearningEpoch(TrainingSet trainingSet) {
        // Record last epoch's error before this epoch evaluates anything.
        this.previousEpochError = this.totalNetworkError;

        System.arraycopy(this.weights, 0, this.bestWeights, 0, this.weights.length);
        double bestError = determineError(trainingSet);
        this.temperature = this.startTemperature;
        double cooling = coolingFactor();

        for (int i = 0; i < this.cycles && !isStopped(); i++) {
            randomize();
            double currentError = determineError(trainingSet);
            // A stop during evaluation yields a partial (artificially small) error; never accept it.
            if (!isStopped() && currentError < bestError) {
                System.arraycopy(this.weights, 0, this.bestWeights, 0, this.weights.length);
                bestError = currentError;
            } else {
                System.arraycopy(this.bestWeights, 0, this.weights, 0, this.weights.length);
            }
            NeuralNetworkCODEC.array2network(this.bestWeights, this.network);
            this.temperature *= cooling;
        }

        this.totalNetworkError = bestError;
        if (hasReachedStopCondition()) {
            stopLearning();
        }
    }

    /**
     * Per-cycle temperature multiplier taking startTemperature to stopTemperature in
     * {@code cycles - 1} steps.
     *
     * @return the multiplier, or 1 (no cooling) when {@code cycles <= 1}, where the schedule
     *     formula would divide by zero
     */
    private double coolingFactor() {
        if (this.cycles <= 1) {
            return 1.0d;
        }
        return Math.exp(Math.log(this.stopTemperature / this.startTemperature) / (this.cycles - 1));
    }

    /** @return the network being trained by this rule */
    public NeuralNetwork getNetwork() {
        return this.network;
    }

    /**
     * Perturbs every weight by {@code (0.5 - Math.random()) / startTemperature * temperature} and
     * writes the perturbed weights into {@link #network}.
     */
    public void randomize() {
        for (int i = 0; i < this.weights.length; i++) {
            this.weights[i] = (((0.5d - Math.random()) / this.startTemperature) * this.temperature) + this.weights[i];
        }
        NeuralNetworkCODEC.array2network(this.weights, this.network);
    }

    /** Intentionally empty: weights are changed by {@link #randomize()}, not by error gradients. */
    @Override // org.neuroph.core.learning.SupervisedLearning
    protected void updateNetworkWeights(double[] dArr) {
    }

    /**
     * Adds (sum of squared errors) / (2 * number of outputs) for one pattern to
     * {@code totalNetworkError}.
     *
     * @param dArr output error vector of one pattern
     */
    protected void updateTotalNetworkError(double[] dArr) {
        double squaredSum = 0.0d;
        for (double e : dArr) {
            squaredSum += e * e;
        }
        this.totalNetworkError = (squaredSum / (dArr.length * 2)) + this.totalNetworkError;
    }
}
