NeuroNetworksBook/Include/realization/neuronbatchnorm.mqh

331 lines
29 KiB
MQL5
Raw Permalink Normal View History

2025-05-30 16:12:34 +02:00
//+------------------------------------------------------------------+
//| NeuronBatchNorm.mqh |
//| Copyright 2021, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2021, MetaQuotes Ltd."
#property link "https://www.mql5.com"
//+------------------------------------------------------------------+
//| Include libraries |
//+------------------------------------------------------------------+
#include "neuronbase.mqh"
//+------------------------------------------------------------------+
//| Class CNeuronBatchNorm |
//| Purpose: batch normalization layer class |
//+------------------------------------------------------------------+
//| Each neuron normalizes one input value using running statistics |
//| and then applies a trainable scale (weights column 0) and shift  |
//| (weights column 1) before the activation function.               |
//+------------------------------------------------------------------+
class CNeuronBatchNorm : public CNeuronBase
{
protected:
//--- per-neuron normalization state, one row per neuron:
//--- column 0 - running mean, column 1 - running variance,
//--- column 2 - normalized value produced by the last forward pass
CBufferDouble *m_cBatchOptions;
uint m_iBatchSize; // normalization batch size (values <= 1 disable normalization)
public:
CNeuronBatchNorm(void);
~CNeuronBatchNorm(void);
//--- initialization and forward/backward passes
virtual bool Init(CLayerDescription *description);
virtual bool FeedForward(CNeuronBase *prevLayer);
virtual bool CalcHiddenGradient(CNeuronBase *prevLayer);
virtual bool CalcDeltaWeights(CNeuronBase *prevLayer);
//--- methods for working with files
virtual bool Save(const int file_handle);
virtual bool Load(const int file_handle);
//--- method of identifying the object
virtual int Type(void) const { return(defNeuronBatchNorm); }
};
//+------------------------------------------------------------------+
//| Class constructor                                                |
//| Allocates the normalization-state buffer. The batch size starts  |
//| at 1, i.e. pass-through mode, until Init() or Load() sets it.    |
//+------------------------------------------------------------------+
CNeuronBatchNorm::CNeuronBatchNorm(void)
  {
   m_iBatchSize = 1;
   m_cBatchOptions = new CBufferDouble;
  }
//+------------------------------------------------------------------+
//| Class destructor                                                 |
//| Releases the normalization-state buffer owned by this layer.     |
//+------------------------------------------------------------------+
CNeuronBatchNorm::~CNeuronBatchNorm(void)
  {
   if(m_cBatchOptions == NULL)
      return;
   delete m_cBatchOptions;
  }
//+------------------------------------------------------------------+
//| Class initialization method                                      |
//| description.count - number of neurons; must equal .window,       |
//|                     because each neuron normalizes one input     |
//| description.batch - normalization batch size (values <= 1        |
//|                     disable normalization)                       |
//| Returns false for an invalid description or on any failure.      |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::Init(CLayerDescription *description)
  {
   if(CheckPointer(description) == POINTER_INVALID ||
      description.window != description.count)
      return false;
//--- each neuron has exactly two trainable parameters (scale and shift),
//--- so the base class is initialized with a window of 1. The caller's
//--- description is restored afterwards (window == count was verified
//--- above) so that initializing this layer does not modify it.
   description.window = 1;
   bool base_ok = CNeuronBase::Init(description);
   description.window = description.count;
   if(!base_ok)
      return false;
//--- initialize the trainable parameters as the identity transform:
//--- scale (column 0) = 1, shift (column 1) = 0. A zero scale would zero
//--- every output and block the gradient to the previous layer.
   if(!m_cWeights.m_mMatrix.Fill(0))
      return false;
   if(!m_cWeights.m_mMatrix.Col(VECTOR::Ones(description.count), 0))
      return false;
//--- if an OpenCL copy of the weights already exists, keep it in sync with the host
   if(m_cOpenCL && m_cWeights.GetIndex() >= 0 && !m_cWeights.BufferWrite())
      return false;
//--- initialize the normalization parameter buffer
//--- (column 0 - mean, column 1 - variance, column 2 - normalized value)
   if(CheckPointer(m_cBatchOptions) == POINTER_INVALID)
     {
      m_cBatchOptions = new CBufferDouble();
      if(CheckPointer(m_cBatchOptions) == POINTER_INVALID)
         return false;
     }
   if(!m_cBatchOptions.BufferInit(description.count, 3, 0))
      return false;
   if(m_cOpenCL && m_cBatchOptions.GetIndex() >= 0 && !m_cBatchOptions.BufferWrite())
      return false;
   m_iBatchSize = description.batch;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Forward pass method                                              |
//| Normalizes each input with running statistics, applies the       |
//| trainable scale/shift and then the activation function.          |
//| With m_iBatchSize <= 1 the inputs are passed through unchanged   |
//| (only the activation is applied).                                |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::FeedForward(CNeuronBase *prevLayer)
  {
//--- control block
   if(!prevLayer || !prevLayer.GetOutputs() || !m_cOutputs || !m_cBatchOptions || !m_cWeights || !m_cActivation)
      return false;
//--- check the normalization batch size: <= 1 means pass-through
   if(m_iBatchSize <= 1)
     {
      m_cOutputs.m_mMatrix = prevLayer.GetOutputs().m_mMatrix;
      if(m_cOpenCL && !m_cOutputs.BufferWrite())
         return false;
      if(!m_cActivation.Activation(m_cOutputs))
         return false;
      return true;
     }
//--- branch the algorithm by compute device
   if(!m_cOpenCL)
     {
      double batch = (double)m_iBatchSize;
      VECTOR inputs = prevLayer.GetOutputs().m_mMatrix.Row(0);
      //--- running statistics: the stored value is weighted by (batch - 1) / batch,
      //--- the current sample by 1 / batch
      VECTOR mean = (m_cBatchOptions.m_mMatrix.Col(0) * (batch - 1.0) + inputs) / batch;
      VECTOR delt = inputs - mean;
      VECTOR variance = (m_cBatchOptions.m_mMatrix.Col(1) * (batch - 1.0) + delt * delt) / batch;
      VECTOR std = variance;
      for(uint r = 0; r < std.Size(); r++)
         std[r] = (std[r] > 0 ? sqrt(std[r]) : 1e-8);
      VECTOR nx = delt / std;
      //--- trainable scale (weights column 0) and shift (weights column 1)
      VECTOR res = m_cWeights.m_mMatrix.Col(0) * nx + m_cWeights.m_mMatrix.Col(1);
      if(!m_cOutputs.m_mMatrix.Row(res, 0) ||
         !m_cBatchOptions.m_mMatrix.Col(mean, 0) ||
         !m_cBatchOptions.m_mMatrix.Col(variance, 1) ||
         !m_cBatchOptions.m_mMatrix.Col(nx, 2))
         return false;
     }
   else // OpenCL block
     {
      //--- create data buffers on demand. The backward-pass methods free some
      //--- of these buffers after use, so they must not be assumed to exist here.
      CBufferDouble *inputs = prevLayer.GetOutputs();
      if(inputs.GetIndex() < 0 && !inputs.BufferCreate(m_cOpenCL))
         return false;
      if(m_cBatchOptions.GetIndex() < 0 && !m_cBatchOptions.BufferCreate(m_cOpenCL))
         return false;
      if(m_cWeights.GetIndex() < 0 && !m_cWeights.BufferCreate(m_cOpenCL))
         return false;
      if(m_cOutputs.GetIndex() < 0 && !m_cOutputs.BufferCreate(m_cOpenCL))
         return false;
      //--- pass arguments to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormFeedForward, def_bnff_inputs, inputs.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormFeedForward, def_bnff_weights, m_cWeights.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormFeedForward, def_bnff_options, m_cBatchOptions.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormFeedForward, def_bnff_outputs, m_cOutputs.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_BatchNormFeedForward, def_bnff_total, m_cOutputs.Total()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_BatchNormFeedForward, def_bnff_batch, m_iBatchSize))
         return false;
      //--- enqueue for execution; each work item handles up to 4 elements
      uint off_set[] = {0};
      uint NDRange[] = { (uint)(m_cOutputs.Total() + 3) / 4 };
      if(!m_cOpenCL.Execute(def_k_BatchNormFeedForward, 1, off_set, NDRange))
         return false;
      //--- get results
      if(!m_cOutputs.BufferRead())
         return false;
      //--- the statistics buffer is passed to the kernel (presumably updated as in
      //--- the CPU branch). Read it back so the host copy, from which the backward
      //--- pass recreates its buffers, stays current.
      if(!m_cBatchOptions.BufferRead())
         return false;
     }
//---
   if(!m_cActivation.Activation(m_cOutputs))
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Method of propagating the gradient through the hidden layer      |
//| Writes the error gradient with respect to this layer's inputs    |
//| into prevLayer's gradient buffer.                                |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::CalcHiddenGradient(CNeuronBase *prevLayer)
  {
//--- control block
   if(!prevLayer || !prevLayer.GetOutputs() || !prevLayer.GetGradients() ||
      !m_cActivation || !m_cBatchOptions || !m_cWeights || !m_cOutputs || !m_cGradients)
      return false;
//--- adjust the error gradient by the derivative of the activation function
   if(!m_cActivation.Derivative(m_cOutputs, m_cGradients))
      return false;
//--- normalization disabled: the gradient passes through unchanged
   if(m_iBatchSize <= 1)
     {
      prevLayer.GetGradients().m_mMatrix = m_cGradients.m_mMatrix;
      if(m_cOpenCL && !prevLayer.GetGradients().BufferWrite())
         return false;
      return true;
     }
//--- branch the algorithm by compute device
   if(!m_cOpenCL)
     {
      VECTOR inputs = prevLayer.GetOutputs().m_mMatrix.Row(0);
      CBufferDouble *inputs_grad = prevLayer.GetGradients();
      double batch = (double)m_iBatchSize;
      uint total = m_cOutputs.Total();
      for(uint i = 0; i < total; i++)
        {
         //--- the gradient buffer holds a single row (row 0)
         double gnx = m_cGradients.m_mMatrix[0, i] * m_cWeights.m_mMatrix[i, 0];
         double var_eps = m_cBatchOptions.m_mMatrix[i, 1] + 1.0e-8;
         double inv_std = 1.0 / MathSqrt(var_eps);
         double centered = inputs[i] - m_cBatchOptions.m_mMatrix[i, 0];
         //--- d(loss)/d(variance) = -(x - mean) / (2 * (var + eps)^(3/2)) * gnx
         double gvar = -centered / (2.0 * var_eps * MathSqrt(var_eps)) * gnx;
         double gmu = -inv_std * gnx - gvar * 2.0 * centered / batch;
         double gx = inv_std * gnx + gmu / batch + gvar * 2.0 * centered / batch;
         if(!inputs_grad.m_mMatrix.Flat(i, gx))
            return false;
        }
     }
   else // OpenCL block
     {
      //--- create data buffers
      CBufferDouble *inputs = prevLayer.GetOutputs();
      CBufferDouble *inputs_grad = prevLayer.GetGradients();
      if(inputs.GetIndex() < 0 && !inputs.BufferCreate(m_cOpenCL))
         return false;
      if(m_cBatchOptions.GetIndex() < 0 && !m_cBatchOptions.BufferCreate(m_cOpenCL))
         return false;
      if(m_cWeights.GetIndex() < 0 && !m_cWeights.BufferCreate(m_cOpenCL))
         return false;
      if(m_cOutputs.GetIndex() < 0 && !m_cOutputs.BufferCreate(m_cOpenCL))
         return false;
      if(m_cGradients.GetIndex() < 0 && !m_cGradients.BufferCreate(m_cOpenCL))
         return false;
      if(inputs_grad.GetIndex() < 0 && !inputs_grad.BufferCreate(m_cOpenCL))
         return false;
      //--- pass arguments to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_inputs, inputs.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_weights, m_cWeights.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_options, m_cBatchOptions.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_gradient, m_cGradients.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcHiddenGradient, def_bnhgr_gradient_inputs, inputs_grad.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_BatchNormCalcHiddenGradient, def_bnhgr_total, m_cOutputs.Total()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_BatchNormCalcHiddenGradient, def_bnhgr_batch, m_iBatchSize))
         return false;
      //--- enqueue for execution (uint work arrays, as in FeedForward)
      uint off_set[] = {0};
      uint NDRange[] = { (uint)(m_cOutputs.Total() + 3) / 4 };
      if(!m_cOpenCL.Execute(def_k_BatchNormCalcHiddenGradient, 1, off_set, NDRange))
         return false;
      //--- get results
      if(!inputs_grad.BufferRead())
         return false;
      //--- release OpenCL context memory
      inputs.BufferFree();
      m_cWeights.BufferFree();
      m_cBatchOptions.BufferFree();
      m_cGradients.BufferFree();
     }
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Method of propagating the gradient to the weight matrix level    |
//| Accumulates deltas for the trainable scale (column 0) and shift  |
//| (column 1) into m_cDeltaWeights. prevLayer is not used.          |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::CalcDeltaWeights(CNeuronBase *prevLayer)
  {
//--- control block (m_cOutputs sizes the OpenCL task below)
   if(!m_cBatchOptions || !m_cGradients || !m_cDeltaWeights || !m_cOutputs)
      return false;
//--- normalization disabled: the forward pass bypasses the trainable parameters
   if(m_iBatchSize <= 1)
      return true;
//--- branch the algorithm by compute device
   if(!m_cOpenCL)
     {
      VECTOR grad = m_cGradients.m_mMatrix.Row(0);
      //--- scale delta += normalized value (options column 2) * gradient
      VECTOR delta = m_cBatchOptions.m_mMatrix.Col(2) * grad + m_cDeltaWeights.m_mMatrix.Col(0);
      if(!m_cDeltaWeights.m_mMatrix.Col(delta, 0))
         return false;
      //--- shift delta += gradient
      if(!m_cDeltaWeights.m_mMatrix.Col(grad + m_cDeltaWeights.m_mMatrix.Col(1), 1))
         return false;
     }
   else // OpenCL block
     {
      //--- create data buffers
      if(m_cBatchOptions.GetIndex() < 0 && !m_cBatchOptions.BufferCreate(m_cOpenCL))
         return false;
      if(m_cGradients.GetIndex() < 0 && !m_cGradients.BufferCreate(m_cOpenCL))
         return false;
      if(m_cDeltaWeights.GetIndex() < 0 && !m_cDeltaWeights.BufferCreate(m_cOpenCL))
         return false;
      //--- pass arguments to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcDeltaWeights, def_bndelt_delta_weights, m_cDeltaWeights.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcDeltaWeights, def_bndelt_options, m_cBatchOptions.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_BatchNormCalcDeltaWeights, def_bndelt_gradient, m_cGradients.GetIndex()))
         return false;
      //--- enqueue for execution: one work item per neuron (uint arrays, as in FeedForward)
      uint off_set[] = {0};
      uint NDRange[] = { (uint)m_cOutputs.Total() };
      if(!m_cOpenCL.Execute(def_k_BatchNormCalcDeltaWeights, 1, off_set, NDRange))
         return false;
      //--- get results
      if(!m_cDeltaWeights.BufferRead())
         return false;
      //--- release OpenCL context memory
      m_cWeights.BufferFree();
      m_cBatchOptions.BufferFree();
      m_cDeltaWeights.BufferFree();
     }
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Method of saving class objects to a file                         |
//| File layout (read back in the same order by Load()):             |
//|   base-class data, batch size (integer), normalization buffer.   |
//| Returns false if the state buffer is missing or a write fails.   |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::Save(const int file_handle)
  {
   if(m_cBatchOptions == NULL)
      return false;
//--- short-circuit evaluation stops at the first failed step
   bool saved = CNeuronBase::Save(file_handle) &&
                FileWriteInteger(file_handle, m_iBatchSize) > 0 &&
                m_cBatchOptions.Save(file_handle);
   return saved;
  }
//+------------------------------------------------------------------+
//| Method of restoring the class from data in a file                |
//| Reads base-class data, the batch size and the normalization      |
//| buffer, in the order written by Save(). The state buffer is      |
//| allocated if it does not exist yet.                              |
//+------------------------------------------------------------------+
bool CNeuronBatchNorm::Load(const int file_handle)
  {
   if(!CNeuronBase::Load(file_handle))
      return false;
   m_iBatchSize = FileReadInteger(file_handle);
//--- make sure the normalization-state buffer exists before loading into it
   if(m_cBatchOptions == NULL)
     {
      m_cBatchOptions = new CBufferDouble();
      if(m_cBatchOptions == NULL)
         return false;
     }
   return m_cBatchOptions.Load(file_handle);
  }
//+------------------------------------------------------------------+