package org.encog.engine.network.train.prop;

import org.encog.engine.EncogEngineError;
import org.encog.engine.concurrency.DetermineWorkload;
import org.encog.engine.concurrency.EngineConcurrency;
import org.encog.engine.concurrency.TaskGroup;
import org.encog.engine.data.EngineDataSet;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.network.train.gradient.FlatGradientWorker;
import org.encog.engine.network.train.gradient.GradientWorkerCPU;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.IntRange;

/* loaded from: classes.dex */
/**
 * Base class for propagation-style training of a {@link FlatNetwork}.
 *
 * <p>Each call to {@link #iteration()} does three things:
 * <ol>
 *   <li>Splits the training records across one or more {@link FlatGradientWorker}s.</li>
 *   <li>Lets the workers accumulate their gradients into {@link #gradients} through
 *       {@link #report(double[], double, Throwable)}.</li>
 *   <li>Applies the per-weight change computed by the subclass in
 *       {@link #updateWeight(double[], double[], int)}.</li>
 * </ol>
 */
public abstract class TrainFlatNetworkProp implements TrainFlatNetwork {
    /** Error of the most recent gradient pass: {@link #totalError} divided by the worker count. */
    protected double currentError;
    /** Gradients summed across all workers for the current iteration; zeroed after each weight update. */
    protected double[] gradients;
    /** The training data viewed as an indexable set (same object as {@link #training}). */
    protected final EngineIndexableSet indexable;
    /** Number of iterations performed (incremented by {@link #iteration()}, settable externally). */
    protected int iteration;
    /** Per-weight array handed to {@link #updateWeight}; not written by this class, its use is up to subclasses. */
    protected double[] lastGradient;
    /** The network whose weights are being trained. */
    protected final FlatNetwork network;
    /** Thread count handed to {@code DetermineWorkload}; 0 unless changed via {@link #setNumThreads(int)}. */
    protected int numThreads;
    /** First/last failure reported by a worker; never cleared once set (see {@link #iteration()}). */
    protected Throwable reportedException;
    /** Sum of the error values reported by workers during the current gradient pass. */
    protected double totalError;
    /** The training data as supplied to the constructor. */
    protected final EngineDataSet training;
    /** Gradient workers, created lazily on the first gradient calculation. */
    protected FlatGradientWorker[] workers;

    /**
     * Creates a trainer for the given network and data.
     *
     * @param flatNetwork   the network to train
     * @param engineDataSet the training data; must also implement {@link EngineIndexableSet}
     * @throws EncogEngineError if {@code engineDataSet} is not indexable
     */
    public TrainFlatNetworkProp(FlatNetwork flatNetwork, EngineDataSet engineDataSet) {
        if (!(engineDataSet instanceof EngineIndexableSet)) {
            throw new EncogEngineError("Training data must be Indexable for this training type.");
        }
        this.training = engineDataSet;
        this.network = flatNetwork;
        this.gradients = new double[this.network.getWeights().length];
        this.lastGradient = new double[this.network.getWeights().length];
        this.indexable = (EngineIndexableSet) engineDataSet;
        this.numThreads = 0;
        this.reportedException = null;
    }

    /**
     * Copies each worker's layer outputs into the next worker's network, so that
     * context carried in the layer outputs flows from one record range to the next.
     */
    private void copyContexts() {
        for (int i = 0; i < this.workers.length - 1; i++) {
            EngineArray.arrayCopy(
                    this.workers[i].getNetwork().getLayerOutput(),
                    this.workers[i + 1].getNetwork().getLayerOutput());
        }
    }

    /**
     * Partitions the training records with {@link DetermineWorkload} and creates one
     * {@link GradientWorkerCPU} per range.
     *
     * <p>Each worker gets a clone of the network and its own additional view of the data set.
     */
    private void init() {
        DetermineWorkload workload =
                new DetermineWorkload(this.numThreads, (int) this.indexable.getRecordCount());
        this.workers = new FlatGradientWorker[workload.getThreadCount()];
        int index = 0;
        for (IntRange range : workload.calculateWorkers()) {
            this.workers[index++] = new GradientWorkerCPU(
                    this.network.clone(), this, this.indexable.openAdditional(),
                    range.getLow(), range.getHigh());
        }
    }

    /**
     * Runs all gradient workers once and stores the resulting error in {@link #currentError}.
     *
     * <p>Workers are created on first use. With more than one worker they run as tasks in
     * an engine task group and this method waits for all of them to finish; a single
     * worker runs on the calling thread. Gradients are expected to arrive through
     * {@link #report(double[], double, Throwable)}.
     */
    public void calculateGradients() {
        if (this.workers == null) {
            init();
        }
        this.workers[0].getNetwork().clearContext();
        // Reset the error accumulator. The original code used FlatNetwork.NO_BIAS_ACTIVATION
        // (a 0.0 constant unrelated to error), which was most likely a decompiler substitution.
        this.totalError = 0.0;
        if (this.workers.length > 1) {
            TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
            for (FlatGradientWorker worker : this.workers) {
                EngineConcurrency.getInstance().processTask(worker, group);
            }
            group.waitForComplete();
        } else {
            this.workers[0].run();
        }
        // NOTE(review): averaging over workers presumes each worker reports its own mean error.
        this.currentError = this.totalError / this.workers.length;
    }

    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public void finishTraining() {
        // Nothing to release in the base implementation.
    }

    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public double getError() {
        return this.currentError;
    }

    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public int getIteration() {
        return this.iteration;
    }

    /** Returns the live per-weight {@link #lastGradient} array (not a copy). */
    public double[] getLastGradient() {
        return this.lastGradient;
    }

    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public FlatNetwork getNetwork() {
        return this.network;
    }

    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public int getNumThreads() {
        return this.numThreads;
    }

    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public EngineDataSet getTraining() {
        return this.training;
    }

    /**
     * Performs one training iteration.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Computes the gradients.</li>
     *   <li>Updates the weights, pruning them if the network is limited.</li>
     *   <li>Pushes the new weights to every worker.</li>
     *   <li>Propagates contexts between workers.</li>
     * </ol>
     *
     * @throws EncogEngineError wrapping any exception a worker reported
     */
    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public void iteration() {
        this.iteration++;
        calculateGradients();
        if (this.network.isLimited()) {
            learnLimited();
        } else {
            learn();
        }
        double[] weights = this.network.getWeights();
        for (FlatGradientWorker worker : this.workers) {
            EngineArray.arrayCopy(weights, 0, worker.getWeights(), 0, weights.length);
        }
        copyContexts();
        // NOTE(review): a worker failure is only surfaced here, after the weights have already
        // been updated from whatever gradients were reported. reportedException is also never
        // cleared, so every later iteration will throw again.
        if (this.reportedException != null) {
            throw new EncogEngineError(this.reportedException);
        }
    }

    /**
     * Performs {@code count} consecutive training iterations.
     *
     * @param count number of iterations to run; values {@code <= 0} do nothing
     */
    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public void iteration(int count) {
        for (int n = 0; n < count; n++) {
            iteration();
        }
    }

    /** Adds the subclass-computed delta to every weight, then zeroes the accumulated gradients. */
    protected void learn() {
        double[] weights = this.network.getWeights();
        for (int i = 0; i < this.gradients.length; i++) {
            weights[i] += updateWeight(this.gradients, this.lastGradient, i);
            this.gradients[i] = 0.0d;
        }
    }

    /**
     * Like {@link #learn()}, but prunes connections whose weight magnitude is below the
     * network's connection limit by setting them to zero instead of updating them.
     */
    protected void learnLimited() {
        double connectionLimit = this.network.getConnectionLimit();
        double[] weights = this.network.getWeights();
        for (int i = 0; i < this.gradients.length; i++) {
            // Compare the magnitude: a plain "weights[i] < limit" would zero every negative weight.
            if (Math.abs(weights[i]) < connectionLimit) {
                weights[i] = 0.0d;
            } else {
                weights[i] += updateWeight(this.gradients, this.lastGradient, i);
            }
            this.gradients[i] = 0.0d;
        }
    }

    /**
     * Callback used by workers to deliver their results. Synchronized on this trainer
     * because several workers may report concurrently.
     *
     * @param workerGradients the worker's gradients, added element-wise into {@link #gradients}
     * @param error           the worker's error, added to {@link #totalError}
     * @param ex              a failure from the worker, or {@code null}; when non-null,
     *                        the gradients and error are ignored and the failure is recorded
     */
    public void report(double[] workerGradients, double error, Throwable ex) {
        synchronized (this) {
            if (ex != null) {
                this.reportedException = ex;
                return;
            }
            for (int i = 0; i < workerGradients.length; i++) {
                this.gradients[i] += workerGradients[i];
            }
            this.totalError += error;
        }
    }

    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public void setIteration(int iteration) {
        this.iteration = iteration;
    }

    @Override // org.encog.engine.network.train.TrainFlatNetwork
    public void setNumThreads(int numThreads) {
        this.numThreads = numThreads;
    }

    /**
     * Computes the change to apply to one weight.
     *
     * @param gradients    accumulated gradients for the current iteration
     * @param lastGradient per-weight state owned by the subclass
     * @param index        index of the weight being updated
     * @return the delta to add to {@code weights[index]}
     */
    public abstract double updateWeight(double[] gradients, double[] lastGradient, int index);
}
