package org.encog.engine.opencl.kernels;

import java.util.Map;
import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.prop.OpenCLTrainingProfile;
import org.encog.engine.opencl.EncogCLDevice;
import org.encog.engine.opencl.EncogCLQueue;
import org.encog.engine.opencl.exceptions.OpenCLError;
import org.encog.engine.opencl.exceptions.OutOfOpenCLResources;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ResourceLoader;
import org.jocl.CLException;
import org.jocl.cl_mem;

/* loaded from: classes.dex */
/**
 * Wraps the OpenCL "NetworkTrain" kernel for training a {@link FlatNetwork}.
 *
 * <p>Lifecycle as implemented here: the constructor copies the training set into flat
 * {@code float} arrays; {@link #compile(Map, OpenCLTrainingProfile, FlatNetwork)} builds the
 * kernel source, compiles it and calls {@link #init(OpenCLTrainingProfile)} to create the device
 * buffers; {@link #calculate(int, int, boolean, int)} uploads weights and parameters, runs the
 * kernel and reads back errors, output weights, temp data and gradients;
 * {@link #release()} frees the buffers.
 */
public class KernelNetworkTrain extends EncogKernel {
    /** Slot in the int parameter array holding the network input count (set in {@code init}). */
    public static final int PARRAY_INPUT_COUNT = 0;
    /** Slot holding the {@code size} argument of {@code calculate}. */
    public static final int PARRAY_ITEMS_PER = 5;
    /** Slot holding the {@code iterations} argument of {@code calculate}. */
    public static final int PARRAY_ITERATIONS = 6;
    /** Slot holding the number of layers (set in {@code init}). */
    public static final int PARRAY_LAYER_COUNT = 2;
    /** Slot holding the learn flag of {@code calculate}: 1 when learning, 0 otherwise. */
    public static final int PARRAY_LEARN = 3;
    /** Slot holding the network output count (set in {@code init}). */
    public static final int PARRAY_OUTPUT_COUNT = 1;
    /** Slot holding the {@code start} argument of {@code calculate}. */
    public static final int PARRAY_START = 4;
    private cl_mem activationTypeBuffer;
    private final EncogCLDevice device;
    private cl_mem errorBuffer;
    private float[] errors;
    private final FlatNetwork flat;
    private cl_mem gradientInBuffer;
    private cl_mem gradientOutBuffer;
    private final float[] gradients;
    private final float[] idealArray;
    private cl_mem idealBuffer;
    private final float[] inputArray;
    private cl_mem inputBuffer;
    private cl_mem layerCountBuffer;
    private int layerDeltaSize;
    private cl_mem layerFeedCountBuffer;
    private cl_mem layerIndexBuffer;
    private final int[] paramArray;
    private cl_mem paramBuffer;
    private final float[] slopeArray;
    private cl_mem slopeBuffer;
    private float[] tempDataArray;
    private cl_mem tempDataInBuffer;
    private cl_mem tempDataOutBuffer;
    private final EngineIndexableSet training;
    private final int trainingLength;
    private final float[] weightInArray;
    private cl_mem weightInArrayBuffer;
    private cl_mem weightIndexBuffer;
    private final float[] weightOutArray;
    private cl_mem weightOutArrayBuffer;

    /**
     * Creates the kernel wrapper and flattens the training data into host-side arrays.
     *
     * <p>Inputs are stored record after record, {@code inputCount} floats per record, in
     * {@code inputArray}; ideals likewise, {@code outputCount} floats per record, in
     * {@code idealArray}. Each activation function contributes {@code getParams()[0]} to the
     * slope array.
     *
     * @param encogCLDevice      device whose queue executes the kernel
     * @param flatNetwork        network being trained
     * @param engineIndexableSet training data; its record count is narrowed to {@code int}
     * @param tempDataSize       length of the temp-data array exchanged with the kernel
     */
    public KernelNetworkTrain(EncogCLDevice encogCLDevice, FlatNetwork flatNetwork, EngineIndexableSet engineIndexableSet, int tempDataSize) {
        super(encogCLDevice, "org/encog/engine/resources/KernelNetTrain.txt", "NetworkTrain");
        this.training = engineIndexableSet;
        this.trainingLength = (int) this.training.getRecordCount();
        this.device = encogCLDevice;
        this.flat = flatNetwork;

        int weightCount = flatNetwork.getWeights().length;
        this.weightInArray = new float[weightCount];
        this.weightOutArray = new float[weightCount];
        this.gradients = new float[weightCount];
        this.tempDataArray = new float[tempDataSize];

        this.layerDeltaSize = 0;
        for (int layerNeurons : flatNetwork.getLayerCounts()) {
            this.layerDeltaSize += layerNeurons;
        }

        ActivationFunction[] activations = flatNetwork.getActivationFunctions();
        this.slopeArray = new float[activations.length];
        for (int i = 0; i < activations.length; i++) {
            // NOTE(review): assumes every activation function exposes at least one param.
            this.slopeArray[i] = (float) activations[i].getParams()[0];
        }

        int inputCount = flatNetwork.getInputCount();
        int outputCount = flatNetwork.getOutputCount();
        this.inputArray = new float[inputCount * this.trainingLength];
        this.idealArray = new float[this.trainingLength * outputCount];
        this.paramArray = new int[10];

        EngineData pair = BasicEngineData.createPair(inputCount, outputCount);
        for (int record = 0; record < this.trainingLength; record++) {
            engineIndexableSet.getRecord(record, pair);
            int inputBase = record * inputCount;
            for (int j = 0; j < inputCount; j++) {
                this.inputArray[inputBase + j] = (float) pair.getInputArray()[j];
            }
            int idealBase = record * outputCount;
            for (int j = 0; j < outputCount; j++) {
                this.idealArray[idealBase + j] = (float) pair.getIdealArray()[j];
            }
        }
    }

    /**
     * Sets the global work size to {@code min(i, i2)} and the local work size to that value
     * capped by the device's maximum work-group size.
     *
     * @param i  first candidate global size
     * @param i2 second candidate global size
     */
    public void assignWorkgroupSizes(int i, int i2) {
        int min = Math.min(i, i2);
        setLocalWork(Math.min(getMaxWorkGroupSize(), min));
        setGlobalWork(min);
    }

    /**
     * Runs one kernel pass: copies the network's current weights in, uploads the parameter,
     * weight, temp-data and (zeroed) gradient arrays, executes, waits, and reads back errors,
     * output weights, temp data and gradients into the host arrays.
     *
     * @param start      value stored in {@link #PARRAY_START}
     * @param size       value stored in {@link #PARRAY_ITEMS_PER}
     * @param learn      stored in {@link #PARRAY_LEARN} as 1 (true) or 0 (false)
     * @param iterations value stored in {@link #PARRAY_ITERATIONS}
     * @throws OutOfOpenCLResources if OpenCL reports {@code CL_OUT_OF_RESOURCES}
     * @throws OpenCLError          for any other failure during upload/execute/readback
     */
    public void calculate(int start, int size, boolean learn, int iterations) {
        prepareKernel();
        this.paramArray[PARRAY_LEARN] = learn ? 1 : 0;
        this.paramArray[PARRAY_START] = start;
        this.paramArray[PARRAY_ITEMS_PER] = size;
        this.paramArray[PARRAY_ITERATIONS] = iterations;
        EngineArray.arrayCopy(this.flat.getWeights(), this.weightInArray);

        setArg(0, this.paramBuffer);
        setArg(1, this.errorBuffer);
        setArg(2, this.layerIndexBuffer);
        setArg(3, this.layerCountBuffer);
        setArg(4, this.layerFeedCountBuffer);
        setArg(5, this.weightIndexBuffer);
        setArg(6, this.inputBuffer);
        setArg(7, this.idealBuffer);
        setArg(8, this.weightInArrayBuffer);
        setArg(9, this.weightOutArrayBuffer);
        setArg(10, this.gradientOutBuffer);
        setArg(11, this.activationTypeBuffer);
        setArg(12, this.slopeBuffer);
        setArg(13, this.tempDataInBuffer);
        setArg(14, this.tempDataOutBuffer);
        setArg(15, this.gradientInBuffer);

        try {
            EncogCLQueue queue = this.device.getQueue();
            EngineArray.fill(this.gradients, 0.0f);
            queue.array2Buffer(this.weightInArray, this.weightInArrayBuffer);
            queue.array2Buffer(this.tempDataArray, this.tempDataInBuffer);
            queue.array2Buffer(this.gradients, this.gradientInBuffer);
            queue.array2Buffer(this.paramArray, this.paramBuffer);
            queue.execute(this);
            queue.waitFinish();
            queue.buffer2Array(this.errorBuffer, this.errors);
            queue.buffer2Array(this.weightOutArrayBuffer, this.weightOutArray);
            queue.buffer2Array(this.tempDataOutBuffer, this.tempDataArray);
            queue.buffer2Array(this.gradientOutBuffer, this.gradients);
        } catch (CLException e) {
            // Must precede catch (Exception): CLException is a RuntimeException, so listing it
            // second makes it unreachable (a javac error) and loses the out-of-resources mapping.
            if ("CL_OUT_OF_RESOURCES".equals(e.getMessage())) {
                throw new OutOfOpenCLResources(e);
            }
            throw new OpenCLError((Throwable) e);
        } catch (Exception e) {
            throw new OpenCLError(e);
        }
    }

    /**
     * Builds the kernel source and compiles it.
     *
     * <p>The source is the loaded resource, prefixed with {@code ACTIVATION} and
     * {@code DERIVATIVE} macros taken from the network's first activation function. After
     * compiling, it lets the profile compute kernel parameters and then allocates buffers via
     * {@link #init(OpenCLTrainingProfile)}.
     *
     * @param map                   options forwarded to the base-class {@code compile(Map)}
     * @param openCLTrainingProfile profile supplying work-group sizing
     * @param flatNetwork           network whose first activation function defines the macros
     */
    public void compile(Map map, OpenCLTrainingProfile openCLTrainingProfile, FlatNetwork flatNetwork) {
        ActivationFunction activationFunction = flatNetwork.getActivationFunctions()[0];
        StringBuilder source = new StringBuilder();
        source.append("#define ACTIVATION(x,slope)").append(activationFunction.getOpenCLExpression(false)).append("\r\n");
        source.append("#define DERIVATIVE(x,slope)").append(activationFunction.getOpenCLExpression(true)).append("\r\n");
        source.append(ResourceLoader.loadString(getSourceName()));
        setCLSource(source.toString());
        compile(map);
        openCLTrainingProfile.calculateKernelParams(this, this.training);
        init(openCLTrainingProfile);
    }

    /** @return per-work-item errors read back by the last {@code calculate}; null before {@code init}. */
    public float[] getErrors() {
        return this.errors;
    }

    /** @return the host-side temp-data array exchanged with the kernel. */
    public float[] getTempDataArray() {
        return this.tempDataArray;
    }

    /** @return the output weights read back by the last {@code calculate}. */
    public float[] getWeightOutArray() {
        return this.weightOutArray;
    }

    /**
     * Fills the static parameter slots and creates every device buffer the kernel uses.
     *
     * <p>The error buffer holds one float per global work item. The gradient output buffer
     * holds one weight-sized block per global work item.
     *
     * @param openCLTrainingProfile supplies the global work-group size
     */
    public void init(OpenCLTrainingProfile openCLTrainingProfile) {
        int globalWork = openCLTrainingProfile.getKernelGlobalWorkgroup();
        int gradientOutSize = globalWork * this.flat.getWeights().length;
        this.errors = new float[globalWork];
        this.paramArray[PARRAY_INPUT_COUNT] = this.flat.getInputCount();
        this.paramArray[PARRAY_OUTPUT_COUNT] = this.flat.getOutputCount();
        this.paramArray[PARRAY_LAYER_COUNT] = this.flat.getLayerCounts().length;
        this.inputBuffer = createArrayReadOnly(this.inputArray);
        this.idealBuffer = createArrayReadOnly(this.idealArray);
        this.errorBuffer = createFloatArrayWriteOnly(globalWork);
        this.gradientOutBuffer = createFloatArrayWriteOnly(gradientOutSize);
        this.gradientInBuffer = createArrayReadOnly(this.gradients);
        this.paramBuffer = createArrayReadOnly(this.paramArray);
        this.layerIndexBuffer = createArrayReadOnly(this.flat.getLayerIndex());
        this.layerCountBuffer = createArrayReadOnly(this.flat.getLayerCounts());
        this.layerFeedCountBuffer = createArrayReadOnly(this.flat.getLayerFeedCounts());
        this.weightInArrayBuffer = createArrayReadOnly(this.weightInArray);
        this.weightOutArrayBuffer = createFloatArrayWriteOnly(this.weightInArray.length);
        this.weightIndexBuffer = createArrayReadOnly(this.flat.getWeightIndex());
        // NOTE(review): the activation-type buffer is filled from the layer counts, not from
        // activation types. The activation itself is compiled in via the ACTIVATION macro, so
        // this may be intentional filler; confirm against KernelNetTrain.txt.
        this.activationTypeBuffer = createArrayReadOnly(this.flat.getLayerCounts());
        this.slopeBuffer = createArrayReadOnly(this.slopeArray);
        this.tempDataInBuffer = createArrayReadOnly(this.tempDataArray);
        this.tempDataOutBuffer = createFloatArrayWriteOnly(this.tempDataArray.length);
    }

    /** Releases the base kernel, then every device buffer created by {@code init}. */
    @Override // org.encog.engine.opencl.kernels.EncogKernel
    public void release() {
        super.release();
        releaseBuffer(this.activationTypeBuffer);
        releaseBuffer(this.errorBuffer);
        releaseBuffer(this.gradientOutBuffer);
        releaseBuffer(this.gradientInBuffer);
        releaseBuffer(this.idealBuffer);
        releaseBuffer(this.inputBuffer);
        releaseBuffer(this.layerCountBuffer);
        releaseBuffer(this.layerFeedCountBuffer);
        releaseBuffer(this.layerIndexBuffer);
        releaseBuffer(this.paramBuffer);
        releaseBuffer(this.slopeBuffer);
        releaseBuffer(this.tempDataInBuffer);
        releaseBuffer(this.tempDataOutBuffer);
        releaseBuffer(this.weightInArrayBuffer);
        releaseBuffer(this.weightIndexBuffer);
        releaseBuffer(this.weightOutArrayBuffer);
    }

    /**
     * Replaces the host-side temp-data array.
     *
     * <p>NOTE(review): the temp-data device buffers are sized from the array present when
     * {@code init} ran. Passing an array of a different length afterwards may mismatch them;
     * confirm the caller preserves the length.
     *
     * @param fArr new temp-data array
     */
    public void setTempDataArray(float[] fArr) {
        this.tempDataArray = fArr;
    }
}
