Original_NNB/MQL5/Include/NeuroNetworksBook/realization/activation.mqh

424 lines
34 KiB
MQL5
Raw Permalink Normal View History

2025-05-30 16:15:14 +02:00
//+------------------------------------------------------------------+
//| activation.mqh |
//| Copyright 2021, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2021, MetaQuotes Ltd."
#property link "https://www.mql5.com"
//+------------------------------------------------------------------+
//| Includes |
//+------------------------------------------------------------------+
#include "bufferdouble.mqh"
//+------------------------------------------------------------------+
//| Class CActivation |
//| Purpose: class implementing the activation function algorithms |
//| and its derivative |
//+------------------------------------------------------------------+
//--- Activation-function helper.
//--- Holds the selected activation type plus two numeric parameters and
//--- offers scalar and whole-buffer evaluation of the function and of its
//--- derivative. The protected CBufferDouble base supplies the storage that
//--- the buffer-level methods use as scratch space (e.g. the pre-activation
//--- values kept for Swish).
class CActivation : protected CBufferDouble
  {
protected:
   //--- selected activation function and its two parameters
   ENUM_ACTIVATION   m_eFunction;
   double            m_adParams[2];
   //--- activation functions
   double            LineActivation(double value);      // linear activation function
   double            SigmoidActivation(double value);   // sigmoid
   double            TanhActivation(double value);      // TANH
   double            LReLUActivation(double value);     // LReLU
   double            SwishActivation(double value);     // Swish
   //--- derivatives of the activation functions
   double            LineDerivative(double value);
   double            SigmoidDerivative(double value);
   double            TanhDerivative(double value);
   double            LReLUDerivative(double value);
   double            SwishDerivative(double value, double input_value);

public:
                     CActivation(void);
                    ~CActivation(void) {}
   //--- selection and query of the activation function
   void              SetFunction(ENUM_ACTIVATION value, double param1 = 1, double param2 = 0);
   ENUM_ACTIVATION   GetFunction(double &params[]);
   //--- parameterless query: forwards to the overload above, passing the
   //--- object's own parameter array as the destination
   ENUM_ACTIVATION   GetFunction(void)
     {
      return GetFunction(m_adParams);
     }
   //--- scalar and buffer-wide evaluation
   double            Activation(double value);
   bool              Activation(CBufferDouble *buffer);
   double            Derivative(double value, double input_value = 1);
   bool              Derivative(CBufferDouble *outputs, CBufferDouble *gradient);
   //--- OpenCL context and public forwarding to the protected base buffer
   virtual bool      SetOpenCL(CMyOpenCL *opencl);
   virtual bool      BufferCreate(void)
     {
      return CBufferDouble::BufferCreate(m_cOpenCL);
     }
   virtual bool      BufferInit(uint count)
     {
      return CBufferDouble::BufferInit(count, 0.0);
     }
   virtual bool      BufferRead(void)
     {
      return CBufferDouble::BufferRead();
     }
   virtual bool      BufferFree(void)
     {
      return CBufferDouble::BufferFree();
     }
   virtual int       GetIndex(void)
     {
      return CBufferDouble::GetIndex();
     }
   //--- methods for working with files
   virtual bool      Save(const int file_handle);
   virtual bool      Load(const int file_handle);
   //--- method of identifying the object
   virtual int       Type(void) const
     {
      return defActivation;
     }
  };
//+------------------------------------------------------------------+
//| Class constructor |
//+------------------------------------------------------------------+
//--- Default state: Swish activation with parameters {1, 0}.
CActivation::CActivation()
  {
   m_eFunction   = ACT_SWISH;
   m_adParams[1] = 0;
   m_adParams[0] = 1;
  }
//+------------------------------------------------------------------+
//| Class initialization function |
//+------------------------------------------------------------------+
//--- Selects the activation function and stores its two parameters.
//--- function - activation type to use from now on
//--- param1   - stored in m_adParams[0]
//--- param2   - stored in m_adParams[1]
//--- For ACT_SOFTMAX the supplied parameters are ignored and {1, 0} is stored.
void CActivation::SetFunction(ENUM_ACTIVATION function, double param1 = 1.0, double param2 = 0.0)
  {
   const bool fixed_params = (function == ACT_SOFTMAX);
//---
   m_eFunction   = function;
   m_adParams[0] = (fixed_params ? 1.0 : param1);
   m_adParams[1] = (fixed_params ? 0.0 : param2);
  }
//+------------------------------------------------------------------+
//| Returns the activation function in use |
//+------------------------------------------------------------------+
//--- Copies the stored parameters into 'params' and returns the current
//--- activation type; returns ACT_None when ArrayCopy copies nothing.
ENUM_ACTIVATION CActivation::GetFunction(double &params[])
  {
   const int copied = ArrayCopy(params, m_adParams);
   if(copied > 0)
      return m_eFunction;
//--- nothing was copied
   return ACT_None;
  }
//+------------------------------------------------------------------+
//| Sets the OpenCL context to use |
//+------------------------------------------------------------------+
//--- Replaces the OpenCL context used by this object.
//--- Returns true when the resulting context pointer is valid.
bool CActivation::SetOpenCL(CMyOpenCL *opencl)
  {
//--- same context as before: only report its validity
   if(m_cOpenCL == opencl)
      return(CheckPointer(m_cOpenCL) != POINTER_INVALID);
//--- NOTE(review): the previous context is deleted here, which assumes this
//--- object owns it - confirm it is not shared with other objects
   if(CheckPointer(m_cOpenCL) != POINTER_INVALID)
      delete m_cOpenCL;
   m_cOpenCL = opencl;
//---
   return(CheckPointer(m_cOpenCL) != POINTER_INVALID);
  }
//+------------------------------------------------------------------+
//| Dispatcher selecting how the activation function is computed |
//+------------------------------------------------------------------+
//--- Applies the selected activation function to a single value.
//--- Types without a scalar implementation here (including ACT_SOFTMAX and
//--- ACT_None) return 'value' unchanged.
double CActivation::Activation(double value)
  {
   switch(m_eFunction)
     {
      case ACT_LINE:
         return LineActivation(value);
      case ACT_SIGMOID:
         return SigmoidActivation(value);
      case ACT_TANH:
         return TanhActivation(value);
      case ACT_LReLU:
         return LReLUActivation(value);
      case ACT_SWISH:
         return SwishActivation(value);
      default:
         break;
     }
//--- pass-through for every other activation type
   return value;
  }
//+------------------------------------------------------------------+
//| Dispatcher selecting how the activation derivative is computed |
//+------------------------------------------------------------------+
//--- Returns the derivative of the selected activation function.
//--- value       - passed to the per-function derivative helper
//--- input_value - forwarded only to the Swish derivative
//--- ACT_SOFTMAX shares the sigmoid derivative; any other type yields 1.
double CActivation::Derivative(double value, double input_value = 1)
  {
   switch(m_eFunction)
     {
      case ACT_LINE:
         return LineDerivative(value);
      case ACT_SIGMOID:
      case ACT_SOFTMAX:
         return SigmoidDerivative(value);
      case ACT_TANH:
         return TanhDerivative(value);
      case ACT_LReLU:
         return LReLUDerivative(value);
      case ACT_SWISH:
         return SwishDerivative(value, input_value);
      default:
         break;
     }
//--- neutral derivative for every other activation type
   return 1;
  }
//+------------------------------------------------------------------+
//| Dispatcher selecting how the activation function is computed |
//+------------------------------------------------------------------+
//--- Applies the selected activation function to every element of 'buffer'
//--- in place. Returns false for an invalid or empty buffer, or when the
//--- internal storage needed for Swish cannot be reserved.
bool CActivation::Activation(CBufferDouble *buffer)
  {
   if(CheckPointer(buffer) == POINTER_INVALID || buffer.m_data_total <= 0)
      return false;
//---
   switch(m_eFunction)
     {
      case ACT_None:
         break;
      case ACT_SOFTMAX:
        {
         //--- Subtract the largest element before exponentiating. The result
         //--- is mathematically unchanged, but MathExp can no longer overflow
         //--- and the sum is at least 1, so the division below is always
         //--- defined for finite inputs.
         double max_value = buffer.m_data[0];
         for(int i = 1; i < buffer.m_data_total; i++)
            if(buffer.m_data[i] > max_value)
               max_value = buffer.m_data[i];
         double sum = 0;
         for(int i = 0; i < buffer.m_data_total; i++)
            sum += buffer.m_data[i] = MathExp(buffer.m_data[i] - max_value);
         //--- normalization
         for(int i = 0; i < buffer.m_data_total; i++)
            buffer.m_data[i] /= sum;
        }
      break;
      case ACT_SWISH:
         if(!Reserve(buffer.m_data_total))
            return false;
         for(int i = 0; i < buffer.m_data_total; i++)
           {
            //--- keep the pre-activation value: the buffer-level Derivative
            //--- reads it back from m_data for Swish
            m_data[i] = buffer.m_data[i];
            buffer.m_data[i] = SwishActivation(buffer.m_data[i]);
           }
         m_data_total = buffer.m_data_total;
         break;
      default:
         //--- element-wise activation through the scalar dispatcher
         for(int i = 0; i < buffer.m_data_total; i++)
            buffer.m_data[i] = Activation(buffer.m_data[i]);
         break;
     }
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Dispatcher selecting how the activation derivative is computed |
//+------------------------------------------------------------------+
//--- Adjusts the error gradient by the derivative of the activation function.
//--- outputs  - activation results
//--- gradient - error gradient; overwritten in place with the adjusted values
//--- Runs on the CPU when no valid OpenCL context is set, otherwise queues
//--- the def_k_DeActivateGradient kernel. Returns false on any failure.
bool CActivation::Derivative(CBufferDouble *outputs, CBufferDouble *gradient)
  {
   if(CheckPointer(outputs) == POINTER_INVALID ||
      CheckPointer(gradient) == POINTER_INVALID)
      return false;
//--- The CPU branches index gradient (and, for Softmax, its copy) once per
//--- element of outputs, and the kernel receives outputs.Total() as the
//--- element count, so a shorter gradient buffer would be read out of range.
   if(gradient.m_data_total < outputs.m_data_total)
      return false;
//---
   if(CheckPointer(m_cOpenCL) == POINTER_INVALID)
     {
      switch(m_eFunction)
        {
         case ACT_None:
            break;
         case ACT_SWISH:
            //--- m_data must hold the values stored by Activation(buffer)
            if(m_data_total < outputs.m_data_total)
               return false;
            for(int i = 0; i < outputs.m_data_total; i++)
               gradient.m_data[i] *= SwishDerivative(outputs.m_data[i], m_data[i]);
            break;
         case ACT_SOFTMAX:
            //--- work on a copy of the incoming gradient, because each
            //--- result depends on every incoming element
            if(!AssignArray(gradient))
               return false;
            for(int i = 0; i < outputs.m_data_total; i++)
              {
               double grad = 0;
               for(int j = 0; j < outputs.m_data_total; j++)
                  grad += outputs.m_data[j] * ((int)(i == j) - outputs.m_data[i]) * m_data[j];
               gradient.m_data[i] = grad;
              }
            break;
         default:
            for(int i = 0; i < outputs.m_data_total; i++)
               gradient.m_data[i] *= Derivative(outputs.m_data[i]);
            break;
        }
     }
   else
     {
      //--- deactivation of the error gradient
      //--- creation of data buffers
      if(gradient.GetIndex() < 0)
         if(!gradient.BufferCreate(m_cOpenCL))
            return false;
      if(outputs.GetIndex() < 0)
         if(!outputs.BufferCreate(m_cOpenCL))
            return false;
      if(m_eFunction == ACT_SOFTMAX)
         if(!AssignArray(gradient))
            return false;
      if(m_data_total != outputs.Total())
        {
         if(!BufferInit(outputs.Total()))
            return false;
        }
      if(!BufferCreate())
         return false;
      //--- passing parameters to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_DeActivateGradient, def_deactgr_sums, GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_DeActivateGradient, def_deactgr_outputs, outputs.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_DeActivateGradient, def_deactgr_gradients, gradient.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_DeActivateGradient, def_deactgr_outputs_total, outputs.Total()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_DeActivateGradient, def_deactgr_activation, (int)m_eFunction))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_DeActivateGradient, def_deactgr_act_param_a, m_adParams[0]))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_DeActivateGradient, def_deactgr_act_param_b, m_adParams[1]))
         return false;
      //--- placing the kernel into the execution queue
      //--- work size is ceil(total / 4); presumably each work item handles
      //--- four elements - confirm against the kernel source
      int s = outputs.Total();
      int d = s % 4;
      s = (s - d) / 4 + (d > 0 ? 1 : 0);
      int NDRange[] = {s};
      int off_set[] = {0};
      if(!m_cOpenCL.Execute(def_k_DeActivateGradient, 1, off_set, NDRange))
         return false;
      //--- getting the results of the operations
      if(!gradient.BufferRead())
         return false;
      BufferFree();
      outputs.BufferFree();
      gradient.BufferFree();
     }
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Linear activation function |
//| Parameters: 'value' weighted sum of the source data for activation |
//| 'm_adParams[0]' slope of the line |
//| 'm_adParams[1]' - vertical shift of the line |
//+------------------------------------------------------------------+
//--- Linear activation: m_adParams[0] * value + m_adParams[1].
double CActivation::LineActivation(double value)
  {
   const double slope = m_adParams[0];
   const double shift = m_adParams[1];
   return (slope * value + shift);
  }
//+------------------------------------------------------------------+
//| Derivative of the linear activation function; returns m_adParams[0] |
//+------------------------------------------------------------------+
//--- Derivative of the linear activation: the slope m_adParams[0],
//--- independent of 'value'.
double CActivation::LineDerivative(double value)
  {
   const double slope = m_adParams[0];
   return slope;
  }
//+------------------------------------------------------------------+
//| Sigmoid activation function |
//| Parameters: 'value' weighted sum of the source data for activation |
//| 'm_adParams[0]' defines the range of the activation function values, |
//| from '0' to 'm_adParams[0]' |
//| 'm_adParams[1]' - vertical shift of the function value |
//+------------------------------------------------------------------+
//--- Scaled and shifted sigmoid:
//--- m_adParams[0] / (1 + exp(-value)) - m_adParams[1].
double CActivation::SigmoidActivation(double value)
  {
   const double denominator = 1 + exp(-value);
   return (m_adParams[0] / denominator - m_adParams[1]);
  }
//+------------------------------------------------------------------+
//| Derivative of the sigmoid activation function |
//| Parameters: 'value' current value of the activation function |
//| 'm_adParams[0]' defines the range of the activation function values, |
//| from '0' to 'm_adParams[0]' |
//| 'm_adParams[1]' - vertical shift of the function value |
//+------------------------------------------------------------------+
//--- Derivative of the scaled sigmoid, computed from the activation output.
//--- value - result of SigmoidActivation
//--- The vertical shift is undone and the result clamped to
//--- [0, m_adParams[0]], giving z; the derivative is z * (1 - z / m_adParams[0]).
double CActivation::SigmoidDerivative(double value)
  {
   const double range = m_adParams[0];
//--- A zero range makes the activation constant, so its derivative is 0.
//--- Returning early also avoids the division by zero below.
   if(range == 0)
      return 0;
//---
   const double z = MathMax(MathMin(range, value + m_adParams[1]), 0);
   return (z * (1 - z / range));
  }
//+------------------------------------------------------------------+
//| TANH |
//| Parameters: 'value' weighted sum of the source data for activation |
//+------------------------------------------------------------------+
//--- Hyperbolic tangent of 'value'.
double CActivation::TanhActivation(double value)
  {
   const double squashed = MathTanh(value);
   return squashed;
  }
//+------------------------------------------------------------------+
//| TANH derivative |
//| Parameters: 'value' current value of the activation function |
//+------------------------------------------------------------------+
//--- Derivative of TANH expressed through the activation output:
//--- 1 - value^2, where 'value' is the result of TanhActivation.
double CActivation::TanhDerivative(double value)
  {
//--- a plain multiplication replaces the general-purpose MathPow call
   return (1 - value * value);
  }
//+------------------------------------------------------------------+
//| LReLU |
//| Parameters: 'value' weighted sum of the source data for activation |
//| 'm_adParams[0]' leak coefficient |
//+------------------------------------------------------------------+
//--- Leaky ReLU: positive values pass through unchanged, all others are
//--- multiplied by m_adParams[0].
double CActivation::LReLUActivation(double value)
  {
   if(value > 0)
      return value;
//--- non-positive branch
   return m_adParams[0] * value;
  }
//+------------------------------------------------------------------+
//| LReLU derivative |
//| Parameters: 'value' current value of the activation function |
//| 'm_adParams[0]' leak coefficient |
//+------------------------------------------------------------------+
//--- Derivative of Leaky ReLU: 1 for positive 'value', otherwise m_adParams[0].
double CActivation::LReLUDerivative(double value)
  {
   if(value > 0)
      return 1.0;
//--- non-positive branch
   return m_adParams[0];
  }
//+------------------------------------------------------------------+
//| Swish |
//| Parameters: 'value' weighted sum of the source data for activation |
//| 'm_adParams[0]' nonlinearity coefficient of the function |
//+------------------------------------------------------------------+
//--- Swish: value / (1 + exp(-value * m_adParams[0])).
double CActivation::SwishActivation(double value)
  {
   const double denominator = 1 + exp(-value * m_adParams[0]);
   return value / denominator;
  }
//+------------------------------------------------------------------+
//| Swish derivative |
//| Parameters: 'value' current value of the activation function |
//| 'input_value' weighted sum of the source data for activation |
//| 'm_adParams[0]' nonlinearity coefficient of the function |
//+------------------------------------------------------------------+
//--- Derivative of Swish computed from the activation output and its input.
//--- value       - activation result
//--- input_value - value that was fed to the activation
//--- Returns 0.5 when input_value is exactly 0, which also avoids dividing
//--- by zero below.
double CActivation::SwishDerivative(double value, double input_value)
  {
   if(input_value == 0)
      return 0.5;
//--- value / input_value equals the sigmoid factor when 'value' is the Swish
//--- output for 'input_value'
   const double scaled = m_adParams[0] * value;
   const double ratio  = value / input_value;
   return (scaled + (ratio * (1 - scaled)));
  }
//+------------------------------------------------------------------+
//| Class save method |
//+------------------------------------------------------------------+
//--- Writes the activation type (as an integer) followed by the parameter
//--- array to the file. Returns false for an invalid handle or when either
//--- write reports nothing written.
bool CActivation::Save(const int file_handle)
  {
   if(file_handle == INVALID_HANDLE)
      return false;
//--- activation type first
   if(FileWriteInteger(file_handle, (int)m_eFunction) <= 0)
      return false;
//--- then the parameters
   if(FileWriteArray(file_handle, m_adParams) <= 0)
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Method restoring the class elements from previously saved data |
//+------------------------------------------------------------------+
//--- Restores the activation type and its parameters in the order Save
//--- writes them. Returns false for an invalid handle or when the parameter
//--- array cannot be read completely; in that case the object is unchanged.
bool CActivation::Load(const int file_handle)
  {
   if(file_handle == INVALID_HANDLE)
      return false;
//--- read into temporaries so a failed or truncated read cannot leave the
//--- object with a new function type and stale parameters
   const ENUM_ACTIVATION function = (ENUM_ACTIVATION)FileReadInteger(file_handle);
   double params[2];
//--- a partial read (fewer elements than the array holds) is a failure
   if((int)FileReadArray(file_handle, params) != ArraySize(params))
      return false;
//--- commit
   m_eFunction   = function;
   m_adParams[0] = params[0];
   m_adParams[1] = params[1];
//---
   return true;
  }
//+------------------------------------------------------------------+