package org.neuroph.nnet.flat;

import java.util.HashSet;
import java.util.Iterator;
import org.encog.engine.EncogEngine;
import org.encog.engine.EncogEngineError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.engine.network.flat.FlatLayer;
import org.encog.engine.network.flat.FlatNetwork;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.core.transfer.Linear;
import org.neuroph.core.transfer.Sigmoid;
import org.neuroph.core.transfer.Tanh;
import org.neuroph.core.transfer.TransferFunction;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.comp.BiasNeuron;
import org.neuroph.nnet.comp.InputNeuron;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.plugins.PluginBase;

/* loaded from: classes.dex */
/**
 * Neuroph plugin that holds an Encog {@link FlatNetwork} mirroring the host network, together with
 * static helpers that convert a {@link MultiLayerPerceptron} to and from that flat representation.
 */
public class FlatNetworkPlugin extends PluginBase {
    /** Name under which this plugin is registered (passed to {@link PluginBase}). */
    public static final String PLUGIN_NAME = "FlatNetworkPlugin";
    private static final long serialVersionUID = 1;

    /** The flat network whose weight array backs the host network's weights once flattened. */
    public FlatNetwork flatNetwork;

    /**
     * Creates the plugin around an existing flat network.
     *
     * @param flatNetwork the flat network exposed through {@link #getFlatNetwork()}
     */
    public FlatNetworkPlugin(FlatNetwork flatNetwork) {
        super(PLUGIN_NAME);
        this.flatNetwork = flatNetwork;
    }

    /**
     * Converts one Neuroph layer into an Encog {@link FlatLayer}.
     *
     * <p>Only neurons whose class is exactly {@link Neuron} or {@link InputNeuron} are counted as
     * layer neurons; neurons whose class is exactly {@link BiasNeuron} are counted as bias; neurons
     * of any other class are neither counted nor rejected.
     *
     * @param layer the layer to convert
     * @return the equivalent flat layer, or {@code null} when the layer cannot be represented: it has
     *     no counted neurons, its neurons use more than one transfer-function class, a non-input
     *     layer uses a transfer function other than {@link Linear}, {@link Sigmoid} or {@link Tanh},
     *     or it contains more than one bias neuron
     */
    private static FlatLayer flattenLayer(Layer layer) {
        HashSet<Class<?>> transferTypes = new HashSet<Class<?>>();
        TransferFunction transfer = null;
        int neuronCount = 0;
        int biasCount = 0;
        boolean inputLayer = false;

        for (Neuron neuron : layer.getNeurons()) {
            Class<?> neuronType = neuron.getClass();
            if (neuronType == InputNeuron.class) {
                inputLayer = true;
            }
            if (neuronType == Neuron.class || neuronType == InputNeuron.class) {
                neuronCount++;
                transfer = neuron.getTransferFunction();
                transferTypes.add(transfer.getClass());
            } else if (neuronType == BiasNeuron.class) {
                biasCount++;
            }
        }

        // A flat layer has a single activation function and at most one bias value. An empty set
        // (no plain/input neurons at all) is rejected here instead of failing in iterator().next().
        if (transferTypes.size() != 1 || biasCount > 1) {
            return null;
        }

        Class<?> transferType = transferTypes.iterator().next();
        ActivationFunction activation;
        double slope;
        if (inputLayer) {
            // Input layers are always mapped to a linear activation with slope 1, whatever transfer
            // function the input neurons carry.
            activation = new ActivationLinear();
            slope = 1.0d;
        } else if (transferType == Linear.class) {
            activation = new ActivationLinear();
            slope = ((Linear) transfer).getSlope();
        } else if (transferType == Sigmoid.class) {
            activation = new ActivationSigmoid();
            slope = ((Sigmoid) transfer).getSlope();
        } else if (transferType == Tanh.class) {
            activation = new ActivationTANH();
            slope = ((Tanh) transfer).getSlope();
        } else {
            return null;
        }

        double biasActivation = biasCount == 1 ? 1.0d : FlatNetwork.NO_BIAS_ACTIVATION;
        return new FlatLayer(activation, neuronCount, biasActivation, new double[] {slope});
    }

    /**
     * Flattens a multilayer perceptron: builds a {@link FlatNetwork} from its layers, rebinds every
     * connection weight to the flat weight array, then attaches a {@link FlatNetworkPlugin} and a
     * {@link FlatNetworkLearning} rule sharing that flat network.
     *
     * @param multiLayerPerceptron the network to flatten
     * @return {@code true} on success; {@code false} if any layer cannot be flattened, in which case
     *     the network is left unchanged
     * @throws EncogEngineError if the network has more connections than the flat weight array; the
     *     check runs before any weight, plugin or learning rule is changed
     */
    private static boolean flattenMultiLayerPerceptron(MultiLayerPerceptron multiLayerPerceptron) {
        int layerCount = multiLayerPerceptron.getLayers().size();
        FlatLayer[] flatLayers = new FlatLayer[layerCount];
        for (int layerIndex = 0; layerIndex < layerCount; layerIndex++) {
            FlatLayer flatLayer = flattenLayer((Layer) multiLayerPerceptron.getLayers().get(layerIndex));
            if (flatLayer == null) {
                return false;
            }
            flatLayers[layerIndex] = flatLayer;
        }

        FlatNetwork flat = new FlatNetwork(flatLayers);
        // Bind weights before attaching the plugin and learning rule so that a weight-size mismatch
        // does not leave them attached to a network whose weights were never rebound.
        flattenWeights(flat, multiLayerPerceptron);
        multiLayerPerceptron.addPlugin(new FlatNetworkPlugin(flat));
        multiLayerPerceptron.setLearningRule(new FlatNetworkLearning(flat));
        return true;
    }

    /**
     * Attempts to flatten the given network. Only {@link MultiLayerPerceptron} instances are
     * supported.
     *
     * @param neuralNetwork the network to flatten
     * @return {@code true} if the network was flattened; {@code false} if it is not a multilayer
     *     perceptron (including {@code null}) or one of its layers cannot be flattened
     * @throws EncogEngineError if the network has more connections than the flat weight array
     */
    public static boolean flattenNeuralNetworkNetwork(NeuralNetwork neuralNetwork) {
        if (neuralNetwork instanceof MultiLayerPerceptron) {
            return flattenMultiLayerPerceptron((MultiLayerPerceptron) neuralNetwork);
        }
        return false;
    }

    /**
     * Counts the input connections of every neuron in layers {@code size-1} down to {@code 1}, i.e.
     * exactly the connections {@link #flattenWeights} rebinds.
     */
    private static int countInputConnections(NeuralNetwork neuralNetwork) {
        int count = 0;
        for (int layerIndex = neuralNetwork.getLayers().size() - 1; layerIndex > 0; layerIndex--) {
            Layer layer = (Layer) neuralNetwork.getLayers().get(layerIndex);
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection ignored : neuron.getInputConnections()) {
                    count++;
                }
            }
        }
        return count;
    }

    /**
     * Replaces each input-connection weight of layers {@code size-1} down to {@code 1} with a
     * {@link FlatWeight} backed by {@code flatNetwork}'s weight array, copying the current value into
     * it. Slots are assigned sequentially in that layer/neuron/connection visiting order.
     *
     * <p>A network with fewer connections than slots is not rejected.
     *
     * @param flatNetwork the flat network whose weight array receives the values
     * @param neuralNetwork the network whose connections are rebound
     * @throws EncogEngineError if there are more connections than slots; checked before any weight is
     *     replaced
     */
    private static void flattenWeights(FlatNetwork flatNetwork, NeuralNetwork neuralNetwork) {
        double[] weights = flatNetwork.getWeights();
        // Validate up front so a mismatch cannot leave only some connections rebound.
        if (countInputConnections(neuralNetwork) > weights.length) {
            throw new EncogEngineError("Weight size mismatch.");
        }

        int slot = 0;
        for (int layerIndex = neuralNetwork.getLayers().size() - 1; layerIndex > 0; layerIndex--) {
            Layer layer = (Layer) neuralNetwork.getLayers().get(layerIndex);
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    FlatWeight flatWeight = new FlatWeight(weights, slot);
                    flatWeight.setValue(connection.getWeight().getValue());
                    connection.setWeight(flatWeight);
                    slot++;
                }
            }
        }
    }

    /** Delegates to {@code EncogEngine.getInstance().initCL()} (presumably OpenCL setup — confirm). */
    public static void initCL() {
        EncogEngine.getInstance().initCL();
    }

    /** Delegates to {@code EncogEngine.getInstance().shutdown()}. */
    public static void shutdown() {
        EncogEngine.getInstance().shutdown();
    }

    /**
     * Reverses flattening: every {@link FlatWeight} on an input connection of layers {@code size-1}
     * down to {@code 1} is replaced by a standalone {@link Weight} holding the same value, and the
     * {@link FlatNetworkPlugin} is removed from the network.
     *
     * @param neuralNetwork the network to restore
     * @return always {@code true}
     */
    public static boolean unFlattenNeuralNetworkNetwork(NeuralNetwork neuralNetwork) {
        for (int layerIndex = neuralNetwork.getLayers().size() - 1; layerIndex > 0; layerIndex--) {
            Layer layer = (Layer) neuralNetwork.getLayers().get(layerIndex);
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    Weight weight = connection.getWeight();
                    if (weight instanceof FlatWeight) {
                        Weight standalone = new Weight(weight.getValue());
                        // Carry the old weight's per-weight training state (e.g. momentum's previous
                        // value) over to the replacement. The previous code assigned the new
                        // weight's field to itself, and dereferenced its not-yet-set training data.
                        standalone.setTrainingData(weight.getTrainingData());
                        connection.setWeight(standalone);
                    }
                }
            }
        }
        neuralNetwork.removePlugin(FlatNetworkPlugin.class);
        return true;
    }

    /**
     * Returns the flat network held by this plugin.
     *
     * @return the flat network passed to the constructor (or later assigned to the public field)
     */
    public FlatNetwork getFlatNetwork() {
        return this.flatNetwork;
    }
}
