Warrior_EA/AI/Network.cl

228 lines
9.4 KiB
OpenCL C
Raw Permalink Normal View History

2025-05-30 16:35:54 +02:00
//--- Some GPUs do not support double precision by default;
//--- the cl_khr_fp64 extension pragma below enables double-precision arithmetic.
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
/**
 * Forward pass of one fully connected layer: one work-item per output neuron.
 *
 * Work-item i reads its weight row matrix_w[(inputs + 1) * i .. (inputs + 1) * i + inputs].
 * The last element of the row is the bias weight, which is multiplied by a constant 1.
 *
 * @param matrix_w    weights, (inputs + 1) values per output neuron
 * @param matrix_i    layer inputs (indices 0 .. inputs - 1 are read)
 * @param matrix_o    layer outputs; element i is written by work-item i
 * @param inputs      number of inputs, excluding the bias
 * @param activation  0 = tanh, 1 = sigmoid (argument clamped to [-50, 50]);
 *                    any other value leaves the weighted sum unchanged
 */
__kernel void FeedForward(__global double *matrix_w,
                          __global double *matrix_i,
                          __global double *matrix_o,
                          int inputs, int activation)
{
    const int i = get_global_id(0);
    const int shift = (inputs + 1) * i;
    double sum = 0.0;
    double4 inp, weight;
    // Walk the weight row four elements at a time. The final iteration
    // (fewer than four real inputs left) also folds in the bias, whose
    // input value is the constant 1. Unused lanes are zero.
    for (int k = 0; k <= inputs; k += 4)
    {
        switch (inputs - k)
        {
        case 0:
            inp = (double4)(1, 0, 0, 0);
            weight = (double4)(matrix_w[shift + k], 0, 0, 0);
            break;
        case 1:
            inp = (double4)(matrix_i[k], 1, 0, 0);
            weight = (double4)(matrix_w[shift + k], matrix_w[shift + k + 1], 0, 0);
            break;
        case 2:
            inp = (double4)(matrix_i[k], matrix_i[k + 1], 1, 0);
            weight = (double4)(matrix_w[shift + k], matrix_w[shift + k + 1],
                               matrix_w[shift + k + 2], 0);
            break;
        case 3:
            inp = (double4)(matrix_i[k], matrix_i[k + 1], matrix_i[k + 2], 1);
            weight = (double4)(matrix_w[shift + k], matrix_w[shift + k + 1],
                               matrix_w[shift + k + 2], matrix_w[shift + k + 3]);
            break;
        default:
            inp = (double4)(matrix_i[k], matrix_i[k + 1], matrix_i[k + 2], matrix_i[k + 3]);
            weight = (double4)(matrix_w[shift + k], matrix_w[shift + k + 1],
                               matrix_w[shift + k + 2], matrix_w[shift + k + 3]);
            break;
        }
        sum += dot(inp, weight);
    }
    switch (activation)
    {
    case 0:
        sum = tanh(sum);
        break;
    case 1:
        // clamp keeps exp() far from overflow
        sum = 1 / (1 + exp(-clamp(sum, -50.0, 50.0)));
        break;
    default:
        // identity: the weighted sum is the output
        break;
    }
    matrix_o[i] = sum;
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
/**
 * Error gradient of the output layer: one work-item per output neuron.
 *
 * gradient = (target - output) * f'(output). For tanh and sigmoid, the target
 * is first clamped to the activation's output range. A saturated output
 * (exactly +/-1 for tanh, exactly 0 or 1 for sigmoid) gets a tiny non-zero
 * derivative, so the gradient does not vanish completely.
 *
 * @param matrix_t    targets
 * @param matrix_o    layer outputs
 * @param matrix_ig   receives the gradient of each output neuron
 * @param activation  0 = tanh, 1 = sigmoid; any other value is treated as the
 *                    identity (derivative 1). This matches FeedForward, which
 *                    leaves the weighted sum unchanged for those values.
 */
__kernel void CaclOutputGradient(__global double *matrix_t,
                                 __global double *matrix_o,
                                 __global double *matrix_ig,
                                 int activation)
{
    const int i = get_global_id(0);
    const double out = matrix_o[i];
    double grad;
    switch (activation)
    {
    case 0:
    {
        // tanh: f' = 1 - out^2; nudge a saturated output off +/-1
        const double o = (out == 1 || out == -1) ? 0.99999999 : out;
        grad = (clamp(matrix_t[i], -1.0, 1.0) - out) * (1 - o * o);
        break;
    }
    case 1:
        // sigmoid: f' = out * (1 - out); tiny derivative when saturated
        grad = (clamp(matrix_t[i], 0.0, 1.0) - out) *
               ((out == 0 || out == 1) ? 0.00000001 : out * (1 - out));
        break;
    default:
        // identity activation: f' = 1. A zero gradient here would stop a
        // linear output layer from ever learning.
        grad = matrix_t[i] - out;
        break;
    }
    matrix_ig[i] = grad;
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
/**
 * Error gradient of a hidden layer: one work-item per neuron of this layer.
 *
 * Back-propagates gradients through the weights, then applies the activation
 * derivative. For tanh and sigmoid, the weighted sum is first turned into a
 * "target delta": (out + sum) is clamped to the activation's output range,
 * and out is subtracted again.
 *
 * @param matrix_w    weights; for work-item i, elements (outputs + 1) * i + k
 *                    are read for k = 0 .. outputs - 1
 * @param matrix_g    gradients being back-propagated (indices 0 .. outputs - 1)
 * @param matrix_o    outputs of this layer
 * @param matrix_ig   receives the gradient of each neuron of this layer
 * @param outputs     number of gradient values in matrix_g
 * @param activation  0 = tanh, 1 = sigmoid; any other value passes the
 *                    back-propagated sum through unchanged
 */
__kernel void CaclHiddenGradient(__global double *matrix_w,
                                 __global double *matrix_g,
                                 __global double *matrix_o,
                                 __global double *matrix_ig,
                                 int outputs, int activation)
{
    const int i = get_global_id(0);
    const double out = matrix_o[i];
    // NOTE(review): this reads matrix_w as [this neuron][gradient index] with
    // stride outputs + 1. FeedForward and the UpdateWeights* kernels index
    // weights as [output neuron][input] with stride inputs + 1. If the host
    // passes the same buffer FeedForward uses for the next layer, the index
    // should be k * (get_global_size(0) + 1) + i. Confirm against the host code.
    const int shift = (outputs + 1) * i;
    double sum = 0;
    double4 grad, weight;
    // Process four gradients at a time. Since k < outputs, 1..3 remaining
    // elements hit the explicit cases and >= 4 hits the default. Each case
    // reads exactly the remaining elements, never matrix_g[outputs].
    for (int k = 0; k < outputs; k += 4)
    {
        switch (outputs - k)
        {
        case 1:
            grad = (double4)(matrix_g[k], 0, 0, 0);
            weight = (double4)(matrix_w[shift + k], 0, 0, 0);
            break;
        case 2:
            grad = (double4)(matrix_g[k], matrix_g[k + 1], 0, 0);
            weight = (double4)(matrix_w[shift + k], matrix_w[shift + k + 1], 0, 0);
            break;
        case 3:
            grad = (double4)(matrix_g[k], matrix_g[k + 1], matrix_g[k + 2], 0);
            weight = (double4)(matrix_w[shift + k], matrix_w[shift + k + 1],
                               matrix_w[shift + k + 2], 0);
            break;
        default:
            grad = (double4)(matrix_g[k], matrix_g[k + 1], matrix_g[k + 2], matrix_g[k + 3]);
            weight = (double4)(matrix_w[shift + k], matrix_w[shift + k + 1],
                               matrix_w[shift + k + 2], matrix_w[shift + k + 3]);
            break;
        }
        sum += dot(grad, weight);
    }
    switch (activation)
    {
    case 0:
    {
        // tanh: f' = 1 - out^2; nudge a saturated output off +/-1
        const double o = (out == 1 || out == -1) ? 0.99999999 : out;
        sum = (clamp(sum + out, -1.0, 1.0) - out) * (1 - o * o);
        break;
    }
    case 1:
        // sigmoid: f' = out * (1 - out); tiny derivative when saturated
        sum = (clamp(sum + out, 0.0, 1.0) - out) *
              ((out == 0 || out == 1) ? 0.00000001 : out * (1 - out));
        break;
    default:
        // identity activation: derivative 1, the sum passes through
        break;
    }
    matrix_ig[i] = sum;
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
/**
 * Gradient step with momentum, one weight per work-item.
 *
 * 2-D launch: dimension 0 selects the output neuron, dimension 1 the slot
 * within that neuron's row of (inputs + 1) weights. Slot `inputs` is the
 * bias, whose input is taken as 1.
 *
 * step = learning_rates * gradient * input + momentum * previous step.
 * The step is saved in matrix_dw for the next call and added to the weight.
 */
__kernel void UpdateWeightsMomentum(__global double *matrix_w,
                                    __global double *matrix_g,
                                    __global double *matrix_i,
                                    __global double *matrix_dw,
                                    int inputs, double learning_rates, double momentum)
{
    const int neuron = get_global_id(0);
    const int slot = get_global_id(1);
    const int idx = neuron * (inputs + 1) + slot;

    double x = 1;
    if (slot < inputs)
    {
        x = matrix_i[slot];
    }

    const double step = learning_rates * matrix_g[neuron] * x + momentum * matrix_dw[idx];
    matrix_dw[idx] = step;
    matrix_w[idx] += step;
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
/**
 * Adam-style weight update, four consecutive weights per work-item.
 *
 * 2-D launch: dimension 0 selects the output neuron i. Dimension 1 selects a
 * group of four weight slots starting at slot j * 4 in that neuron's row of
 * (inputs + 1) weights. Slot `inputs` is the bias, whose input is taken as 1.
 *
 * matrix_m holds the first-moment estimate and matrix_v the (un-rooted)
 * second-moment estimate of each weight. No bias correction is applied,
 * because no timestep is passed in.
 *
 * @param l   learning rate
 * @param b1  decay rate of the first moment
 * @param b2  decay rate of the second moment
 */
__kernel void UpdateWeightsAdam(__global double *matrix_w,
                                __global const double *matrix_g,
                                __global const double *matrix_i,
                                __global double *matrix_m,
                                __global double *matrix_v,
                                const int inputs, const double l, const double b1, const double b2)
{
    const int i = get_global_id(0);
    const int j = get_global_id(1);
    // First weight slot handled here; it is also the index of the matching input.
    const int shift_i = j * 4;
    // Work-items past the end of the row (over-sized launch) have nothing to do.
    if (shift_i > inputs)
        return;
    const int wi = i * (inputs + 1) + shift_i;
    double4 m, v, inp;
    switch (inputs - shift_i)
    {
    case 0:
        inp = (double4)(1, 0, 0, 0);
        m = (double4)(matrix_m[wi], 0, 0, 0);
        v = (double4)(matrix_v[wi], 0, 0, 0);
        break;
    case 1:
        inp = (double4)(matrix_i[shift_i], 1, 0, 0);
        m = (double4)(matrix_m[wi], matrix_m[wi + 1], 0, 0);
        v = (double4)(matrix_v[wi], matrix_v[wi + 1], 0, 0);
        break;
    case 2:
        inp = (double4)(matrix_i[shift_i], matrix_i[shift_i + 1], 1, 0);
        m = (double4)(matrix_m[wi], matrix_m[wi + 1], matrix_m[wi + 2], 0);
        v = (double4)(matrix_v[wi], matrix_v[wi + 1], matrix_v[wi + 2], 0);
        break;
    case 3:
        inp = (double4)(matrix_i[shift_i], matrix_i[shift_i + 1], matrix_i[shift_i + 2], 1);
        m = (double4)(matrix_m[wi], matrix_m[wi + 1], matrix_m[wi + 2], matrix_m[wi + 3]);
        v = (double4)(matrix_v[wi], matrix_v[wi + 1], matrix_v[wi + 2], matrix_v[wi + 3]);
        break;
    default:
        inp = (double4)(matrix_i[shift_i], matrix_i[shift_i + 1],
                        matrix_i[shift_i + 2], matrix_i[shift_i + 3]);
        m = (double4)(matrix_m[wi], matrix_m[wi + 1], matrix_m[wi + 2], matrix_m[wi + 3]);
        v = (double4)(matrix_v[wi], matrix_v[wi + 1], matrix_v[wi + 2], matrix_v[wi + 3]);
        break;
    }
    const double4 g = matrix_g[i] * inp;
    const double4 mt = b1 * m + (1 - b1) * g;
    // Keep the raw second moment in state; only its square root is the divisor.
    const double4 vt = b2 * v + (1 - b2) * g * g;
    const double4 delta = l * mt / (vt > 0 ? sqrt(vt) : l * 10);
    switch (inputs - shift_i)
    {
    case 2:
        matrix_w[wi + 2] += delta.s2;
        matrix_m[wi + 2] = mt.s2;
        matrix_v[wi + 2] = vt.s2;
        /* fallthrough */
    case 1:
        matrix_w[wi + 1] += delta.s1;
        matrix_m[wi + 1] = mt.s1;
        matrix_v[wi + 1] = vt.s1;
        /* fallthrough */
    case 0:
        matrix_w[wi] += delta.s0;
        matrix_m[wi] = mt.s0;
        matrix_v[wi] = vt.s0;
        break;
    default:
        // Three inputs + bias, or four inputs: all four lanes are real weights.
        matrix_w[wi] += delta.s0;
        matrix_m[wi] = mt.s0;
        matrix_v[wi] = vt.s0;
        matrix_w[wi + 1] += delta.s1;
        matrix_m[wi + 1] = mt.s1;
        matrix_v[wi + 1] = vt.s1;
        matrix_w[wi + 2] += delta.s2;
        matrix_m[wi + 2] = mt.s2;
        matrix_v[wi + 2] = vt.s2;
        matrix_w[wi + 3] += delta.s3;
        matrix_m[wi + 3] = mt.s3;
        matrix_v[wi + 3] = vt.s3;
        break;
    }
}
//+------------------------------------------------------------------+