Original_NNB/MQL5/Include/NeuroNetworksBook/realization/neuronbatchnorm.mqh

375 lines
32 KiB
MQL5
Raw Permalink Normal View History

2025-05-30 16:15:14 +02:00
//+------------------------------------------------------------------+
//| NeuronBatchNorm.mqh |
//| Copyright 2021, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2021, MetaQuotes Ltd."
#property link "https://www.mql5.com"
//+------------------------------------------------------------------+
//| Include libraries |
//+------------------------------------------------------------------+
#include "neuronbase.mqh"
//+------------------------------------------------------------------+
//| Class CNeuronBatchNorm |
//| Purpose: batch normalization layer class |
//+------------------------------------------------------------------+
//--- Batch normalization layer. Each neuron's input is normalized with running
//--- mean/variance statistics, then scaled by gamma and shifted by beta
//--- (trainable weights) before the activation function is applied.
//--- Trainable weights are stored as one (gamma, beta) pair per neuron i:
//--- m_cWeights[i*2] = gamma, m_cWeights[i*2+1] = beta.
class CNeuronBatchNorm : public CNeuronBase
{
protected:
//--- normalization statistics, 3 values per neuron i:
//--- [i*3] running mean, [i*3+1] running variance, [i*3+2] last normalized value (nx)
CBufferDouble *m_cBatchOptions;
uint m_iBatchSize; // Normalization batch size; values <= 1 disable normalization (pass-through)
public:
CNeuronBatchNorm(void);
~CNeuronBatchNorm(void);
//--- Init requires description.window == description.count (element-wise layer)
virtual bool Init(CLayerDescription *description);
virtual bool FeedForward(CNeuronBase *prevLayer);
virtual bool CalcHiddenGradient(CNeuronBase *prevLayer);
virtual bool CalcDeltaWeights(CNeuronBase *prevLayer);
//--- methods for working with files
virtual bool Save(const int file_handle);
virtual bool Load(const int file_handle);
//--- method of identifying the object
virtual int Type(void) const { return(defNeuronBatchNorm); }
};
//+------------------------------------------------------------------+
//| Class constructor |
//+------------------------------------------------------------------+
CNeuronBatchNorm::CNeuronBatchNorm(void)
  {
//--- default: batch size 1, i.e. normalization disabled until Init()/Load()
   m_iBatchSize = 1;
//--- statistics buffer is allocated up front; Init()/Load() re-create it if it is missing
   m_cBatchOptions = new CBufferDouble;
  }
//+------------------------------------------------------------------+
//| Class destructor |
//+------------------------------------------------------------------+
CNeuronBatchNorm::~CNeuronBatchNorm(void)
  {
//--- release the statistics buffer owned by this layer (if it still exists)
   if(CheckPointer(m_cBatchOptions) == POINTER_INVALID)
      return;
   delete m_cBatchOptions;
  }
//+------------------------------------------------------------------+
//| Class initialization method |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::Init(CLayerDescription *description)
  {
//--- batch normalization is element-wise: the input window must equal the layer size
   if(CheckPointer(description) == POINTER_INVALID ||
      description.window != description.count)
      return false;
//--- Initialize the base class with a window of 1 so the weight buffer holds the
//--- per-neuron (gamma, beta) pair indexed as [i*2], [i*2+1] by the other methods.
//--- NOTE(review): this presumes CNeuronBase::Init allocates (window + 1) weights per
//--- neuron — confirm against neuronbase.mqh.
//--- The caller's description is restored afterwards (window == count was checked
//--- above) so this call does not leave the caller's object modified.
   description.window = 1;
   bool base_ok = CNeuronBase::Init(description);
   description.window = description.count;
   if(!base_ok)
      return false;
//--- Trainable parameters: gamma (even index) = 1, beta (odd index) = 0, i.e. an
//--- identity transform of the normalized value. Starting gamma at 1 here (instead
//--- of relying only on the lazy reset in the CPU forward pass) means the OpenCL
//--- kernels, whose weight changes are not read back, also start from gamma = 1.
   if(!m_cWeights.BufferInit(m_cWeights.Total(), 0))
      return false;
   int weights_total = m_cWeights.Total();
   for(int w = 0; w < weights_total; w += 2)
      if(!m_cWeights.Update(w, 1))
         return false;
//--- normalization statistics buffer: mean, variance, normalized value per neuron
   if(CheckPointer(m_cBatchOptions) == POINTER_INVALID)
     {
      m_cBatchOptions = new CBufferDouble();
      if(CheckPointer(m_cBatchOptions) == POINTER_INVALID)
         return false;
     }
   if(!m_cBatchOptions.BufferInit(description.count * 3, 0))
      return false;
   m_iBatchSize = description.batch;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Feed-forward method |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::FeedForward(CNeuronBase *prevLayer)
  {
//--- validate every object used below
   if(CheckPointer(prevLayer) == POINTER_INVALID ||
      CheckPointer(prevLayer.GetOutputs()) == POINTER_INVALID ||
      CheckPointer(m_cOutputs) == POINTER_INVALID ||
      CheckPointer(m_cBatchOptions) == POINTER_INVALID ||
      CheckPointer(m_cWeights) == POINTER_INVALID ||
      CheckPointer(m_cActivation) == POINTER_INVALID)
      return false;
//--- batch size <= 1: normalization disabled, inputs are copied and only activated
   if(m_iBatchSize <= 1)
     {
      if(!m_cOutputs.AssignArray(prevLayer.GetOutputs()))
         return false;
      if(!m_cActivation.Activation(m_cOutputs))
         return false;
      return true;
     }
//--- branch by computing device
   if(CheckPointer(m_cOpenCL) == POINTER_INVALID)
     {
      int total = m_cOutputs.Total();
      CBufferDouble *inputs = prevLayer.GetOutputs();
      double batch = (double)m_iBatchSize;
      for(int i = 0; i < total; i++)
        {
         //--- per-neuron offsets: 3 statistics (mean, variance, nx) and 2 weights (gamma, beta)
         int shift_options = i * 3;
         int shift_weights = i * 2;
         double x = inputs[i];
         double prev_mean = m_cBatchOptions[shift_options];
         double prev_variance = m_cBatchOptions[shift_options + 1];
         //--- Running mean. While both statistics are still zero (nothing accumulated
         //--- yet) the current sample seeds the mean; otherwise the sample is averaged
         //--- in with weight 1/batch. The test must be "either is non-zero": requiring
         //--- both would skip the division on the second sample (the variance is still 0
         //--- then) and inflate the mean to about prev_mean * (batch - 1).
         double mean = prev_mean * (batch - 1.0) + x;
         if(prev_mean != 0 || prev_variance != 0)
            mean /= batch;
         //--- running variance, seeded the same way from the first non-zero deviation
         double delt = x - mean;
         double variance = prev_variance * (batch - 1.0) + delt * delt;
         if(prev_variance > 0)
            variance /= batch;
         double nx = delt / MathSqrt(variance + 1e-8);
         //--- gamma == 0 is treated as "not initialized" and reset to 1
         if(m_cWeights[shift_weights] == 0)
            if(!m_cWeights.Update(shift_weights, 1))
               return false;
         //--- scale and shift the normalized value
         double res = m_cWeights[shift_weights] * nx + m_cWeights[shift_weights + 1];
         if(!m_cOutputs.Update(i, res))
            return false;
         if(!m_cBatchOptions.Update(shift_options, mean) ||
            !m_cBatchOptions.Update(shift_options + 1, variance) ||
            !m_cBatchOptions.Update(shift_options + 2, nx))
            return false;
        }
     }
   else // OpenCL block
     {
      //--- NOTE(review): def_k_BatchNormFeedForward is implemented outside this file;
      //--- confirm it seeds/averages the running mean with the same condition as the
      //--- CPU branch above so both devices produce the same statistics.
      //--- create data buffers
      CBufferDouble *inputs = prevLayer.GetOutputs();
      if(inputs.GetIndex() < 0 && !inputs.BufferCreate(m_cOpenCL))
         return false;
      if(m_cBatchOptions.GetIndex() < 0 && !m_cBatchOptions.BufferCreate(m_cOpenCL))
         return false;
      if(m_cWeights.GetIndex() < 0 && !m_cWeights.BufferCreate(m_cOpenCL))
         return false;
      if(m_cOutputs.GetIndex() < 0 && !m_cOutputs.BufferCreate(m_cOpenCL))
         return false;
      //--- pass parameters to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormFeedForward, def_bnff_inputs, inputs.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormFeedForward, def_bnff_weights, m_cWeights.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormFeedForward, def_bnff_options, m_cBatchOptions.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormFeedForward, def_bnff_outputs, m_cOutputs.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_BatchNormFeedForward, def_bnff_total, m_cOutputs.Total()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_BatchNormFeedForward, def_bnff_batch, m_iBatchSize))
         return false;
      //--- enqueue for execution: one work item per group of 4 neurons (rounded up)
      int off_set[] = {0};
      int s = m_cOutputs.Total();
      int d = s % 4;
      s = (s - d) / 4 + (d > 0 ? 1 : 0);
      int NDRange[] = {s};
      if(!m_cOpenCL.Execute(def_k_BatchNormFeedForward, 1, off_set, NDRange))
         return false;
      //--- retrieve results
      if(!m_cOutputs.BufferRead())
         return false;
      if(!m_cBatchOptions.BufferRead())
         return false;
      //--- free kernel memory
      inputs.BufferFree();
      m_cWeights.BufferFree();
      m_cBatchOptions.BufferFree();
      m_cOutputs.BufferFree();
     }
//--- activation function
   if(!m_cActivation.Activation(m_cOutputs))
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Propagates the error gradient through the hidden layer |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::CalcHiddenGradient(CNeuronBase *prevLayer)
  {
//--- validate every object used below (m_cOutputs and m_cGradients are dereferenced too)
   if(CheckPointer(prevLayer) == POINTER_INVALID ||
      CheckPointer(prevLayer.GetOutputs()) == POINTER_INVALID ||
      CheckPointer(prevLayer.GetGradients()) == POINTER_INVALID ||
      CheckPointer(m_cActivation) == POINTER_INVALID ||
      CheckPointer(m_cOutputs) == POINTER_INVALID ||
      CheckPointer(m_cGradients) == POINTER_INVALID ||
      CheckPointer(m_cBatchOptions) == POINTER_INVALID ||
      CheckPointer(m_cWeights) == POINTER_INVALID)
      return false;
//--- adjust the error gradient by the derivative of the activation function
   if(!m_cActivation.Derivative(m_cOutputs, m_cGradients))
      return false;
//--- normalization disabled: the gradient passes through unchanged
   if(m_iBatchSize <= 1)
      return prevLayer.GetGradients().AssignArray(m_cGradients);
//---
   CBufferDouble *inputs = prevLayer.GetOutputs();
   CBufferDouble *inputs_grad = prevLayer.GetGradients();
//--- branch by computing device
   if(CheckPointer(m_cOpenCL) == POINTER_INVALID)
     {
      double batch = (double)m_iBatchSize;
      int total = m_cOutputs.Total();
      for(int i = 0; i < total; i++)
        {
         int shift_options = i * 3;
         int shift_weights = i * 2;
         double mean = m_cBatchOptions[shift_options];
         double variance = m_cBatchOptions[shift_options + 1];
         double centered = inputs[i] - mean;
         //--- gradient w.r.t. the normalized value: dL/dnx = grad * gamma
         double gnx = m_cGradients[i] * m_cWeights[shift_weights];
         double std_inv = 1.0 / MathSqrt(variance + 1e-8);
         //--- dL/dvar = dL/dnx * (x - mean) * (-1/2) * (var + eps)^(-3/2)
         double gvar = -0.5 * centered * gnx * std_inv * std_inv * std_inv;
         //--- dL/dmean = -dL/dnx / sqrt(var + eps) - dL/dvar * 2 * (x - mean) / batch
         double gmu = -std_inv * gnx - gvar * 2.0 * centered / batch;
         //--- dL/dx = dL/dnx / sqrt(var + eps) + dL/dmean / batch + dL/dvar * 2 * (x - mean) / batch
         double gx = std_inv * gnx + gmu / batch + gvar * 2.0 * centered / batch;
         if(!inputs_grad.Update(i, gx))
            return false;
        }
     }
   else // OpenCL block
     {
      //--- create data buffers (the kernel does not take this layer's outputs,
      //--- so no device buffer is created for m_cOutputs)
      if(inputs.GetIndex() < 0 && !inputs.BufferCreate(m_cOpenCL))
         return false;
      if(m_cBatchOptions.GetIndex() < 0 && !m_cBatchOptions.BufferCreate(m_cOpenCL))
         return false;
      if(m_cWeights.GetIndex() < 0 && !m_cWeights.BufferCreate(m_cOpenCL))
         return false;
      if(m_cGradients.GetIndex() < 0 && !m_cGradients.BufferCreate(m_cOpenCL))
         return false;
      if(inputs_grad.GetIndex() < 0 && !inputs_grad.BufferCreate(m_cOpenCL))
         return false;
      //--- pass parameters to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_inputs, inputs.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_weights, m_cWeights.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_options, m_cBatchOptions.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_gradient, m_cGradients.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_gradient_inputs, inputs_grad.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_BatchNormCalcHiddenGradient, def_bnhgr_total, m_cOutputs.Total()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_BatchNormCalcHiddenGradient, def_bnhgr_batch, m_iBatchSize))
         return false;
      //--- enqueue for execution: one work item per group of 4 neurons (rounded up)
      int off_set[] = {0};
      int s = m_cOutputs.Total();
      int d = s % 4;
      s = (s - d) / 4 + (d > 0 ? 1 : 0);
      int NDRange[] = {s};
      if(!m_cOpenCL.Execute(def_k_BatchNormCalcHiddenGradient, 1, off_set, NDRange))
         return false;
      //--- retrieve results
      if(!inputs_grad.BufferRead())
         return false;
      //--- free OpenCL context memory. The previous layer's gradient buffer is freed
      //--- as well: its host copy is now up to date, and a buffer left on the device
      //--- would be reused by later "create if missing" checks instead of being
      //--- re-created from the (possibly modified) host data.
      inputs.BufferFree();
      inputs_grad.BufferFree();
      m_cWeights.BufferFree();
      m_cBatchOptions.BufferFree();
      m_cGradients.BufferFree();
     }
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Propagates the gradient to the weight matrix level |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::CalcDeltaWeights(CNeuronBase *prevLayer)
  {
//--- validate the objects used below
   if(CheckPointer(m_cBatchOptions) == POINTER_INVALID ||
      CheckPointer(m_cGradients) == POINTER_INVALID ||
      CheckPointer(m_cDeltaWeights) == POINTER_INVALID)
      return false;
//--- normalization disabled: gamma/beta are not used, nothing to accumulate
   if(m_iBatchSize <= 1)
      return true;
//--- branch by computing device
   if(CheckPointer(m_cOpenCL) == POINTER_INVALID)
     {
      int total = m_cGradients.Total();
      for(int i = 0; i < total; i++)
        {
         int shift_weights = i * 2;
         int shift_options = i * 3;
         double grad = m_cGradients[i];
         //--- gamma: accumulate grad * nx (normalized value stored by FeedForward)
         double delta = m_cBatchOptions[shift_options + 2] * grad + m_cDeltaWeights[shift_weights];
         if(!m_cDeltaWeights.Update(shift_weights, delta))
            return false;
         //--- beta: accumulate grad
         if(!m_cDeltaWeights.Update(shift_weights + 1, grad + m_cDeltaWeights[shift_weights + 1]))
            return false;
        }
     }
   else // OpenCL block
     {
      //--- m_cOutputs sizes the NDRange below
      if(CheckPointer(m_cOutputs) == POINTER_INVALID)
         return false;
      //--- create data buffers
      if(m_cBatchOptions.GetIndex() < 0 && !m_cBatchOptions.BufferCreate(m_cOpenCL))
         return false;
      if(m_cGradients.GetIndex() < 0 && !m_cGradients.BufferCreate(m_cOpenCL))
         return false;
      if(m_cDeltaWeights.GetIndex() < 0 && !m_cDeltaWeights.BufferCreate(m_cOpenCL))
         return false;
      //--- pass parameters to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcDeltaWeights, def_bndelt_delta_weights, m_cDeltaWeights.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcDeltaWeights, def_bndelt_options, m_cBatchOptions.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcDeltaWeights, def_bndelt_gradient, m_cGradients.GetIndex()))
         return false;
      //--- enqueue for execution: one work item per neuron
      int off_set[] = {0};
      int NDRange[] = {m_cOutputs.Total()};
      if(!m_cOpenCL.Execute(def_k_BatchNormCalcDeltaWeights, 1, off_set, NDRange))
         return false;
      //--- retrieve results
      if(!m_cDeltaWeights.BufferRead())
         return false;
      //--- Free context memory. m_cGradients is released too: it was created above,
      //--- and a buffer left on the device would be picked up by the "create if
      //--- missing" check in the next CalcHiddenGradient, presumably without
      //--- re-uploading the host gradients adjusted by the activation derivative.
      m_cWeights.BufferFree();
      m_cBatchOptions.BufferFree();
      m_cGradients.BufferFree();
      m_cDeltaWeights.BufferFree();
     }
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Saves the class objects to a file |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::Save(const int file_handle)
  {
//--- the statistics buffer is part of the saved state; refuse to save without it
   if(CheckPointer(m_cBatchOptions) == POINTER_INVALID)
      return false;
//--- file layout: parent class data, normalization batch size, normalization statistics.
//--- && short-circuits, so writing stops at the first failed step.
   bool saved = CNeuronBase::Save(file_handle) &&
                FileWriteInteger(file_handle, m_iBatchSize) > 0 &&
                m_cBatchOptions.Save(file_handle);
   return saved;
  }
//+------------------------------------------------------------------+
//| Restores the class from data in a file |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::Load(const int file_handle)
  {
//--- read in the same order Save wrote: parent class data first
   if(!CNeuronBase::Load(file_handle))
      return false;
//--- normalization batch size
   m_iBatchSize = FileReadInteger(file_handle);
//--- make sure the statistics buffer exists before reading into it
   if(CheckPointer(m_cBatchOptions) == POINTER_INVALID)
      m_cBatchOptions = new CBufferDouble();
   if(CheckPointer(m_cBatchOptions) == POINTER_INVALID)
      return false;
//--- normalization statistics
   return m_cBatchOptions.Load(file_handle);
  }
//+------------------------------------------------------------------+