package org.encog.engine.network.train.gradient;

import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.prop.TrainFlatNetworkProp;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.Stopwatch;

/* loaded from: classes.dex */
/**
 * CPU implementation of {@link FlatGradientWorker} that computes backpropagation
 * gradients for a contiguous slice of a training set.
 *
 * <p>Each call to {@link #run()} processes the training records with indexes
 * {@code low} through {@code high} (inclusive), accumulating one gradient per
 * weight. It then hands the gradients and the error for that slice to the owning
 * {@link TrainFlatNetworkProp}. The gradient accumulator is cleared after every
 * run, including failed ones, so the worker can be run again.
 *
 * <p>The worker mutates its scratch buffers without synchronization, so a single
 * instance must not execute {@link #run()} concurrently with itself.
 */
public class GradientWorkerCPU implements FlatGradientWorker {
    /** Network outputs for the record currently being processed. */
    private final double[] actual;
    /** Duration of the last run, as reported by {@link Stopwatch#getElapsedTicks()}. */
    private long elapsedTime;
    /** Per-weight gradient accumulator; same length as the network's weight array. */
    private final double[] gradients;
    /** Index of the last training record processed (inclusive). */
    private final int high;
    private final int[] layerCounts;
    /** Per-neuron error deltas; same length as the network's layer-output array. */
    private final double[] layerDelta;
    private final int[] layerFeedCounts;
    private final int[] layerIndex;
    private final double[] layerOutput;
    /** Index of the first training record processed. */
    private final int low;
    private final FlatNetwork network;
    /** Receives the gradients and error (or the failure) at the end of each run. */
    private final TrainFlatNetworkProp owner;
    /** Reusable holder that each training record is read into. */
    private final EngineData pair;
    private final EngineIndexableSet training;
    private final int[] weightIndex;
    private final double[] weights;
    private final ErrorCalculation errorCalculation = new ErrorCalculation();
    private final Stopwatch stopwatch = new Stopwatch();

    /**
     * Creates a worker for one slice of the training set.
     *
     * <p>The structure arrays ({@code getWeights()}, {@code getLayerOutput()}, ...)
     * are captured once here. NOTE(review): {@link #process} reads
     * {@code layerOutput} right after {@code network.compute(...)}, which assumes
     * the getters return the network's live arrays rather than copies — confirm
     * against FlatNetwork.
     *
     * @param network the network whose gradients are computed
     * @param owner the trainer that receives gradient and error reports
     * @param training the training data
     * @param low index of the first record to process
     * @param high index of the last record to process (inclusive)
     */
    public GradientWorkerCPU(FlatNetwork network, TrainFlatNetworkProp owner,
            EngineIndexableSet training, int low, int high) {
        this.network = network;
        this.training = training;
        this.low = low;
        this.high = high;
        this.owner = owner;
        this.layerDelta = new double[network.getLayerOutput().length];
        this.gradients = new double[network.getWeights().length];
        this.actual = new double[network.getOutputCount()];
        this.weights = network.getWeights();
        this.layerIndex = network.getLayerIndex();
        this.layerCounts = network.getLayerCounts();
        this.weightIndex = network.getWeightIndex();
        this.layerOutput = network.getLayerOutput();
        this.layerFeedCounts = network.getLayerFeedCounts();
        this.pair = BasicEngineData.createPair(network.getInputCount(), network.getOutputCount());
    }

    /**
     * Runs one record forward through the network and backpropagates its error,
     * adding its contribution to {@link #gradients}.
     *
     * @param input the record's input values
     * @param ideal the record's expected output values
     */
    private void process(final double[] input, final double[] ideal) {
        this.network.compute(input, this.actual);
        this.errorCalculation.updateError(this.actual, ideal);

        // Output deltas: f'(actual) * (ideal - actual), stored in
        // layerDelta[0 .. outputCount). Activation function 0 is the one applied to
        // the outputs. NOTE(review): this assumes the output layer occupies the
        // start of the layer arrays (layerIndex[0] == 0) — confirm against FlatNetwork.
        final ActivationFunction outputActivation = this.network.getActivationFunctions()[0];
        for (int i = 0; i < this.actual.length; i++) {
            this.layerDelta[i] = outputActivation.derivativeFunction(this.actual[i])
                    * (ideal[i] - this.actual[i]);
        }

        final int endLevel = this.network.getEndTraining();
        for (int level = this.network.getBeginTraining(); level < endLevel; level++) {
            processLevel(level);
        }
    }

    /**
     * Backpropagates deltas across one pair of adjacent layers.
     *
     * <p>The deltas of layer {@code currentLevel} must already be known. For each
     * neuron {@code y} of layer {@code currentLevel + 1} this method:
     * <ul>
     *   <li>adds {@code delta[x] * output[y]} to the gradient of every weight
     *       connecting {@code y} to a neuron {@code x} of layer {@code currentLevel};</li>
     *   <li>sets {@code delta[y] = f'(output[y]) * sum(weight[y->x] * delta[x])}.</li>
     * </ul>
     *
     * <p>The weight for the pair (x, y) lives at
     * {@code weightIndex[currentLevel] + x * fromLayerSize + y}.
     *
     * @param currentLevel index of the layer whose deltas are already known
     */
    private void processLevel(final int currentLevel) {
        final int fromLayerIndex = this.layerIndex[currentLevel + 1];
        final int toLayerIndex = this.layerIndex[currentLevel];
        final int fromLayerSize = this.layerCounts[currentLevel + 1];
        // NOTE(review): layerFeedCounts (rather than layerCounts) presumably excludes
        // bias neurons of the lower layer, which have no incoming weights — confirm.
        final int toLayerSize = this.layerFeedCounts[currentLevel];
        final int weightStart = this.weightIndex[currentLevel];
        final ActivationFunction activation = this.network.getActivationFunctions()[currentLevel + 1];

        int yi = fromLayerIndex;
        for (int y = 0; y < fromLayerSize; y++, yi++) {
            final double output = this.layerOutput[yi];
            // Weighted sum of downstream deltas; must start at zero.
            double sum = 0.0;
            int wi = weightStart + y;
            int xi = toLayerIndex;
            for (int x = 0; x < toLayerSize; x++, xi++, wi += fromLayerSize) {
                final double delta = this.layerDelta[xi];
                this.gradients[wi] += delta * output;
                sum += this.weights[wi] * delta;
            }
            this.layerDelta[yi] = activation.derivativeFunction(output) * sum;
        }
    }

    /**
     * Returns the duration of the most recent {@link #run()}, in the units of
     * {@link Stopwatch#getElapsedTicks()}.
     */
    @Override
    public long getElapsedTime() {
        return this.elapsedTime;
    }

    /** Returns the network this worker computes gradients for. */
    @Override
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /** Returns the weight array captured from the network at construction time. */
    @Override
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Computes the gradients for records {@code low..high} (inclusive).
     *
     * <p>On success, reports the gradients and the slice's error to the owner. On
     * any failure, reports the {@link Throwable} to the owner instead of letting it
     * escape the task. In both cases the gradient accumulator is cleared afterwards
     * and the elapsed time is recorded.
     */
    @Override
    public void run() {
        try {
            this.stopwatch.reset();
            this.stopwatch.start();
            this.errorCalculation.reset();
            for (int i = this.low; i <= this.high; i++) {
                this.training.getRecord(i, this.pair);
                process(this.pair.getInputArray(), this.pair.getIdealArray());
            }
            this.owner.report(this.gradients, this.errorCalculation.calculate(), null);
        } catch (Throwable th) {
            this.owner.report(null, 0.0, th);
        } finally {
            // Always clear the accumulator. Otherwise partial sums from a failed run
            // would be added into the next run's gradients.
            EngineArray.fill(this.gradients, 0.0);
            this.stopwatch.stop();
            this.elapsedTime = this.stopwatch.getElapsedTicks();
        }
    }
}
