package org.neuroph.samples;

import java.util.Arrays;
import org.encog.engine.network.flat.FlatNetwork;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.learning.SupervisedTrainingElement;
import org.neuroph.core.learning.TrainingSet;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.TransferFunctionType;

/* loaded from: classes.dex */
/**
 * Sample that trains a {@link MultiLayerPerceptron} on the XOR problem.
 *
 * <p>After training, it prints the network's outputs. It then saves the network to a file,
 * loads it back, and prints the loaded network's outputs for comparison.
 */
public class XorMultiLayerPerceptronSample {

    /** File the trained network is saved to and then loaded back from. */
    private static final String NETWORK_FILE = "myMlPerceptron.nnet";

    /**
     * Runs the sample.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Build the XOR training set.</li>
     *   <li>Train a 2-3-1 TANH perceptron, in batch mode when the learning rule is
     *       {@link MomentumBackpropagation}.</li>
     *   <li>Print the trained network's outputs.</li>
     *   <li>Save the network to {@value #NETWORK_FILE}, load it back, and print the loaded
     *       network's outputs.</li>
     * </ol>
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        // XOR truth table: 2 inputs -> 1 output. Plain 0.0 literals are used here rather than
        // Encog's FlatNetwork.NO_BIAS_ACTIVATION, an unrelated library constant that merely
        // has the same value and obscured what the training data actually is.
        TrainingSet trainingSet = new TrainingSet(2, 1);
        trainingSet.addElement(new SupervisedTrainingElement(new double[]{0.0d, 0.0d}, new double[]{0.0d}));
        trainingSet.addElement(new SupervisedTrainingElement(new double[]{0.0d, 1.0d}, new double[]{1.0d}));
        trainingSet.addElement(new SupervisedTrainingElement(new double[]{1.0d, 0.0d}, new double[]{1.0d}));
        trainingSet.addElement(new SupervisedTrainingElement(new double[]{1.0d, 1.0d}, new double[]{0.0d}));

        // 2 input neurons, 3 hidden neurons, 1 output neuron, TANH transfer function.
        MultiLayerPerceptron multiLayerPerceptron = new MultiLayerPerceptron(TransferFunctionType.TANH, 2, 3, 1);
        // Only switch to batch mode when the configured learning rule supports it.
        if (multiLayerPerceptron.getLearningRule() instanceof MomentumBackpropagation) {
            ((MomentumBackpropagation) multiLayerPerceptron.getLearningRule()).setBatchMode(true);
        }

        System.out.println("Training neural network...");
        multiLayerPerceptron.learn(trainingSet);

        System.out.println("Testing trained neural network");
        testNeuralNetwork(multiLayerPerceptron, trainingSet);

        // Round-trip through the file system to show that a persisted network can be reused.
        multiLayerPerceptron.save(NETWORK_FILE);
        NeuralNetwork load = NeuralNetwork.load(NETWORK_FILE);

        System.out.println("Testing loaded neural network");
        testNeuralNetwork(load, trainingSet);
    }

    /**
     * Feeds each element's input to the network and prints the input together with the
     * network's output.
     *
     * @param neuralNetwork the network to evaluate
     * @param trainingSet   the elements whose inputs are fed to the network
     */
    public static void testNeuralNetwork(NeuralNetwork neuralNetwork, TrainingSet trainingSet) {
        for (SupervisedTrainingElement supervisedTrainingElement : trainingSet.elements()) {
            double[] input = supervisedTrainingElement.getInput();
            neuralNetwork.setInput(input);
            neuralNetwork.calculate();
            double[] output = neuralNetwork.getOutput();
            System.out.print("Input: " + Arrays.toString(input));
            System.out.println(" Output: " + Arrays.toString(output));
        }
    }
}
