// NeuroNetworksBook/Include/realization/neuronattention.mqh
// (web-scrape header lines removed; file metadata kept as a comment)
//+------------------------------------------------------------------+
//| NeuronAttention.mqh |
//| Copyright 2021, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2021, MetaQuotes Ltd."
#property link "https://www.mql5.com"
//+------------------------------------------------------------------+
//| Include libraries                                                |
//+------------------------------------------------------------------+
#ifndef Defines
#include "defines.mqh"
#endif
#include "neuronconv.mqh"
#include <Math\Stat\Math.mqh>
//+------------------------------------------------------------------+
//| Class CNeuronAttention |
//| Purpose: Self-Attention block class                              |
//+------------------------------------------------------------------+
class CNeuronAttention : public CNeuronBase
{
protected:
CNeuronConv *m_cQuerys;        // Query projection (convolutional layer)
CNeuronConv *m_cKeys;          // Key projection (convolutional layer)
CNeuronConv *m_cValues;        // Value projection (convolutional layer)
CBufferDouble *m_cScores;      // Softmax-normalized attention scores (m_iUnits x m_iUnits)
CBufferDouble *m_cScoreGrad;   // score gradients, used on the OpenCL backward pass
CBufferDouble *m_cScoreTemp;   // scratch buffer for score gradients (OpenCL backward pass)
CNeuronBase *m_cAttentionOut;  // attention output after residual connection + normalization
CNeuronConv *m_cFF1;           // first Feed Forward layer (window_out expanded x4, SWISH)
CNeuronConv *m_cFF2;           // second Feed Forward layer; its buffers double as this layer's outputs/gradients
//---
int m_iWindow;                 // size of one sequence element's description vector
int m_iUnits;                  // number of elements in the sequence
int m_iKeysSize;               // size of the Key/Query vectors
double m_dStd[2];              // std. deviations saved by the two forward normalization steps (reused in backward)
public:
CNeuronAttention(void);
~CNeuronAttention(void);
//---
virtual bool Init(CLayerDescription *desc);
virtual bool SetOpenCL(CMyOpenCL *opencl);
virtual bool FeedForward(CNeuronBase *prevLayer);
virtual bool CalcHiddenGradient(CNeuronBase *prevLayer);
virtual bool CalcDeltaWeights(CNeuronBase *prevLayer);
virtual bool UpdateWeights(int batch_size, double learningRate,
double &Beta[], double &Lambda[]);
//--- file handling methods
virtual bool Save(const int file_handle);
virtual bool Load(const int file_handle);
//--- object identification method
virtual int Type(void) const { return(defNeuronAttention); }
};
//+------------------------------------------------------------------+
//| Class constructor                                                |
//+------------------------------------------------------------------+
//--- Constructor: creates all inner objects and sets safe defaults.
//--- m_dStd is initialized to 1 so the backward pass can divide by it
//--- even before the first forward pass.
CNeuronAttention::CNeuronAttention(void) : m_iWindow(1),
                                           m_iUnits(0),
                                           m_iKeysSize(1)
  {
   m_cQuerys = new CNeuronConv;
   m_cKeys = new CNeuronConv;
   m_cValues = new CNeuronConv;
   m_cScores = new CBufferDouble;
   //--- FIX: these two buffers are deleted in the destructor and used by
   //--- the OpenCL backward pass, but were never allocated anywhere —
   //--- dereferencing an invalid pointer is a runtime error in MQL5
   m_cScoreGrad = new CBufferDouble;
   m_cScoreTemp = new CBufferDouble;
   m_cAttentionOut = new CNeuronBase();
   m_cFF1 = new CNeuronConv;
   m_cFF2 = new CNeuronConv;
   ArrayInitialize(m_dStd, 1);
  }
//+------------------------------------------------------------------+
//| Class destructor                                                 |
//+------------------------------------------------------------------+
//--- Destructor: releases every dynamically created inner object.
//--- Note: m_cOutputs/m_cGradients alias FF2's buffers (see Init), so
//--- they are not deleted here — the base class / FF2 own them.
CNeuronAttention::~CNeuronAttention(void)
  {
   if(CheckPointer(m_cQuerys) != POINTER_INVALID)
      delete m_cQuerys;
   if(CheckPointer(m_cKeys) != POINTER_INVALID)
      delete m_cKeys;
   if(CheckPointer(m_cValues) != POINTER_INVALID)
      delete m_cValues;
   if(CheckPointer(m_cScores) != POINTER_INVALID)
      delete m_cScores;
   if(CheckPointer(m_cScoreGrad) != POINTER_INVALID)
      delete m_cScoreGrad;
   if(CheckPointer(m_cScoreTemp) != POINTER_INVALID)
      delete m_cScoreTemp;
   if(CheckPointer(m_cAttentionOut) != POINTER_INVALID)
      delete m_cAttentionOut;
   if(CheckPointer(m_cFF1) != POINTER_INVALID)
      delete m_cFF1;
   if(CheckPointer(m_cFF2) != POINTER_INVALID)
      delete m_cFF2;
  }
//+------------------------------------------------------------------+
//| Class initialization method                                      |
//+------------------------------------------------------------------+
bool CNeuronAttention::Init(CLayerDescription *desc)
{
//--- validate the input description: correct type, positive sequence
//--- length (count), element size (window) and key size (window_out)
if(!desc || desc.type != Type() || desc.count <= 0 || desc.window <= 0 || desc.window_out <= 0)
return false;
//--- remember the layer geometry
m_iWindow = desc.window;
m_iUnits = desc.count;
m_iKeysSize = desc.window_out;
//--- create a description for the inner neural layers
CLayerDescription *temp = new CLayerDescription();
if(!temp)
return false;
temp.type = defNeuronConv;
temp.window = desc.window;
temp.window_out = m_iKeysSize;
temp.step = desc.window;         // non-overlapping: step equals window
temp.count = desc.count;
temp.activation = ACT_None;
temp.activation_params[0] = 1;
temp.activation_params[1] = 0;
temp.optimization = desc.optimization;
//--- call the parent class initialization method
//--- NOTE(review): 'desc' is mutated below and never restored — the
//--- caller must not reuse it afterwards; confirm this is intended
desc.count *= desc.window;       // parent layer holds the flattened sequence
desc.window_out = 1;
desc.window = 0;
if(!CNeuronBase::Init(desc))
{
delete temp;
return false;
}
//--- initialize Querys (created lazily if the constructor was bypassed)
if(!m_cQuerys)
{
m_cQuerys = new CNeuronConv();
if(!m_cQuerys)
{
delete temp;
return false;
}
}
if(!m_cQuerys.Init(temp))
{
delete temp;
return false;
}
m_cQuerys.SetTransposedOutput(true);
//--- initialize Keys
if(!m_cKeys)
{
m_cKeys = new CNeuronConv();
if(!m_cKeys)
{
delete temp;
return false;
}
}
if(!m_cKeys.Init(temp))
{
delete temp;
return false;
}
m_cKeys.SetTransposedOutput(true);
//--- initialize Values
if(!m_cValues)
{
m_cValues = new CNeuronConv();
if(!m_cValues)
{
delete temp;
return false;
}
}
temp.window_out = m_iWindow;     // Values keep the full element size
if(!m_cValues.Init(temp))
{
delete temp;
return false;
}
m_cValues.SetTransposedOutput(true);
//--- initialize the Scores buffer (m_iUnits x m_iUnits matrix)
if(!m_cScores)
{
m_cScores = new CBufferDouble();
if(!m_cScores)
{
delete temp;
return false;
}
}
if(!m_cScores.BufferInit(temp.count, temp.count, 0))
{
delete temp;
return false;
}
//--- initialize AttentionOut (same size as the layer output)
if(!m_cAttentionOut)
{
if(!(m_cAttentionOut = new CNeuronBase()))
{
delete temp;
return false;
}
}
desc.type = defNeuronBase;       // reuse the mutated desc for the plain base layer
if(!m_cAttentionOut.Init(desc))
{
delete temp;
return false;
}
//--- initialize FF1 (expansion x4 with SWISH activation)
if(!m_cFF1)
{
m_cFF1 = new CNeuronConv();
if(!m_cFF1)
{
delete temp;
return false;
}
}
temp.window_out *= 4;
temp.activation = ACT_SWISH;
temp.activation_params[0] = 1;
temp.activation_params[1] = 0;
if(!m_cFF1.Init(temp))
{
delete temp;
return false;
}
m_cFF1.SetTransposedOutput(true);
//--- initialize FF2 (projection back to the element size, no activation)
if(!m_cFF2)
{
m_cFF2 = new CNeuronConv();
if(!m_cFF2)
{
delete temp;
return false;
}
}
temp.window = temp.window_out;
temp.window_out = temp.step;
temp.step = temp.window;
temp.activation = ACT_None;//desc.activation;
temp.activation_params[0] = 1;//desc.activation_params[0];
temp.activation_params[1] = 0;//desc.activation_params[1];
if(!m_cFF2.Init(temp))
{
delete temp;
return false;
}
m_cFF2.SetTransposedOutput(true);
delete temp;
//--- to avoid copying buffers, substitute this layer's output/gradient
//--- pointers with the buffers of the last inner layer (FF2)
if(m_cOutputs)
delete m_cOutputs;
m_cOutputs = m_cFF2.GetOutputs();
if(m_cGradients)
delete m_cGradients;
m_cGradients = m_cFF2.GetGradients();
//--- pass the OpenCL object pointer down to all inner objects
SetOpenCL(m_cOpenCL);
//--- NOTE(review): m_cScoreGrad and m_cScoreTemp are not created/sized
//--- here although the OpenCL backward pass uses them — confirm they
//--- are allocated elsewhere
return true;
}
//+------------------------------------------------------------------+
//| Method passing the OpenCL object pointer to all internal        |
//| objects of the class                                             |
//+------------------------------------------------------------------+
//--- Propagates the OpenCL context pointer to all inner layers.
//--- Returns true only if an OpenCL context is actually set.
bool CNeuronAttention::SetOpenCL(CMyOpenCL *opencl)
  {
   CNeuronBase::SetOpenCL(opencl);   // stores the pointer in m_cOpenCL
   if(m_cQuerys)
      m_cQuerys.SetOpenCL(m_cOpenCL);
   if(m_cKeys)
      m_cKeys.SetOpenCL(m_cOpenCL);
   if(m_cValues)
      m_cValues.SetOpenCL(m_cOpenCL);
   //--- FIX: AttentionOut was skipped, although FeedForward and
   //--- CalcHiddenGradient use its GPU buffers (GetIndex) on the
   //--- OpenCL path
   if(m_cAttentionOut)
      m_cAttentionOut.SetOpenCL(m_cOpenCL);
   if(m_cFF1)
      m_cFF1.SetOpenCL(m_cOpenCL);
   if(m_cFF2)
      m_cFF2.SetOpenCL(m_cOpenCL);
//---
   return(!!m_cOpenCL);
  }
//+------------------------------------------------------------------+
//| Feed-forward method                                              |
//+------------------------------------------------------------------+
//--- Feed-forward pass: Q/K/V projections, scaled dot-product attention,
//--- residual connection + normalization, Feed Forward block, second
//--- residual connection + normalization.
bool CNeuronAttention::FeedForward(CNeuronBase *prevLayer)
  {
//--- validate all objects used below
   if(!prevLayer || !prevLayer.GetOutputs() || !m_cQuerys || !m_cValues || !m_cKeys || !m_cFF1 || !m_cFF2)
      return false;
//--- compute the Query, Key and Value projections
   if(!m_cQuerys.FeedForward(prevLayer))
      return false;
   if(!m_cKeys.FeedForward(prevLayer))
      return false;
   if(!m_cValues.FeedForward(prevLayer))
      return false;
//--- lazily create the Scores buffer
   if(!m_cScores)
      if(!(m_cScores = new CBufferDouble()))
         return false;
//--- lazily create the AttentionOut layer
   if(!m_cAttentionOut)
     {
      if(!(m_cAttentionOut = new CNeuronBase()))
         return false;
      CLayerDescription *temp = new CLayerDescription();
      if(!temp)
         return false;
      temp.type = defNeuronBase;
      temp.count = (int)m_cOutputs.Total();
      temp.window = 0;
      if(!m_cAttentionOut.Init(temp))
        {
         delete temp;
         return false;
        }
      delete temp;
     }
//--- branch the algorithm by compute device
   MATRIX out;
   if(!m_cOpenCL)
     {
      MATRIX querys = m_cQuerys.GetOutputs().m_mMatrix;
      MATRIX keys = m_cKeys.GetOutputs().m_mMatrix;
      if(!querys.Reshape(m_iUnits, m_iKeysSize) ||
         !keys.Reshape(m_iUnits, m_iKeysSize))
         return false;
      //--- scaled dot-product scores: Q * K^T / sqrt(key size)
      MATRIX scores = querys.MatMul(keys.Transpose()) / sqrt(m_iKeysSize);
      for(int r = 0; r < m_iUnits; r++)
         for(int c = 0; c < m_iUnits; c++)
            scores[r, c] = MathExp(scores[r, c]);
      VECTOR summs = scores.Sum(0);
      //--- Softmax normalization of the scores
      for(int r = 0; r < m_iUnits; r++)
         if(!scores.Row(scores.Row(r) / summs[r], r))
            return false;
      m_cScores.m_mMatrix = scores;
      //--- attention block output: Scores * Values
      MATRIX values = m_cValues.GetOutputs().m_mMatrix;
      if(!values.Reshape(m_iUnits, m_iWindow))
         return false;
      out = scores.MatMul(values);
      if(!out.Reshape(1, m_iUnits * m_iWindow))
         return false;
      //--- residual connection with the input followed by normalization
      out += prevLayer.GetOutputs().m_mMatrix;
      double mean = out.Mean();
      m_dStd[0] = out.Std();
      if(m_dStd[0] == 0)
         m_dStd[0] = 1;   // constant vector: skip scaling instead of dividing by zero
      //--- FIX: also keep 'out' normalized (as the OpenCL branch does),
      //--- because 'out' feeds the second residual connection below;
      //--- previously the CPU and OpenCL branches diverged here
      m_cAttentionOut.GetOutputs().m_mMatrix = out = (out - mean) / m_dStd[0];
     }
   else // OpenCL branch
     {
      //--- make sure all GPU data buffers exist
      if(m_cQuerys.GetOutputs().GetIndex() < 0)
         return false;
      if(m_cKeys.GetOutputs().GetIndex() < 0)
         return false;
      if(m_cValues.GetOutputs().GetIndex() < 0)
         return false;
      if(m_cScores.GetIndex() < 0)
         return false;
      if(m_cAttentionOut.GetOutputs().GetIndex() < 0)
         return false;
      //--- pass the kernel arguments
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_keys, m_cKeys.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_outputs, m_cAttentionOut.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_querys, m_cQuerys.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_scores, m_cScores.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_values, m_cValues.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_AttentionFeedForward, def_attff_key_size, m_iKeysSize))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_AttentionFeedForward, def_attff_window, m_iWindow))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_AttentionFeedForward, def_attff_mask, 0))
         return false;
      //--- enqueue the kernel: one work item per sequence element
      int off_set[] = {0, 0};
      int NDRange[] = {m_iUnits, 1};
      if(!m_cOpenCL.Execute(def_k_AttentionFeedForward, 2, off_set, NDRange))
         return false;
      //--- read back the attention output
      if(!m_cAttentionOut.GetOutputs().GetData(out, true))
         return false;
      //--- residual connection with the input followed by normalization
      out += prevLayer.GetOutputs().m_mMatrix;
      double mean = out.Mean();
      m_dStd[0] = out.Std();
      if(m_dStd[0] == 0)
         m_dStd[0] = 1;   // constant vector: skip scaling instead of dividing by zero
      m_cAttentionOut.GetOutputs().m_mMatrix = out = (out - mean) / m_dStd[0];
      if(!m_cAttentionOut.GetOutputs().BufferWrite())
         return false;
     }
//--- feed-forward of the Feed Forward block layers
   if(!m_cFF1.FeedForward(m_cAttentionOut))
      return false;
   if(!m_cFF2.FeedForward(m_cFF1))
      return false;
//--- second residual connection (attention output) and normalization
   out += m_cOutputs.m_mMatrix;
   double mean = out.Mean();
   m_dStd[1] = out.Std();
   if(m_dStd[1] == 0)
      m_dStd[1] = 1;      // constant vector: skip scaling instead of dividing by zero
   m_cOutputs.m_mMatrix = (out - mean) / m_dStd[1];
   if(m_cOpenCL && !m_cOutputs.BufferWrite())
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Method of propagating the gradient through the hidden layer     |
//+------------------------------------------------------------------+
//--- Backward pass: propagates the error gradient through the Feed
//--- Forward block, the attention mechanism (Softmax derivative) and
//--- the Q/K/V projections down to the previous layer.
bool CNeuronAttention::CalcHiddenGradient(CNeuronBase *prevLayer)
  {
//--- validate all objects used below
//--- FIX: m_cFF1 and m_cAttentionOut are dereferenced in this method
//--- but were not checked before
   if(CheckPointer(m_cOutputs) == POINTER_INVALID ||
      CheckPointer(m_cGradients) == POINTER_INVALID ||
      CheckPointer(m_cScores) == POINTER_INVALID ||
      CheckPointer(m_cFF1) == POINTER_INVALID ||
      CheckPointer(m_cFF2) == POINTER_INVALID ||
      CheckPointer(m_cAttentionOut) == POINTER_INVALID ||
      CheckPointer(m_cQuerys) == POINTER_INVALID ||
      CheckPointer(m_cKeys) == POINTER_INVALID ||
      CheckPointer(m_cValues) == POINTER_INVALID ||
      m_cOutputs.Total() != m_cGradients.Total())
      return false;
//--- undo the scaling of the second normalization step
   if(m_dStd[1] != 0 && m_cGradients.Scaling(1 / m_dStd[1]) <= 0)
      return false;
//--- propagate the gradient through the Feed Forward block
   if(!m_cFF2.CalcHiddenGradient(m_cFF1))
      return false;
   if(!m_cFF1.CalcHiddenGradient(m_cAttentionOut))
      return false;
   CBufferDouble *attention_grad = m_cAttentionOut.GetGradients();
   uint total = m_cOutputs.Total();
   if(!attention_grad.SumArray(m_cGradients))   // add the residual branch gradient
      return false;
   if(m_dStd[0] != 0 && attention_grad.Scaling(1 / m_dStd[0]) <= 0)
      return false;
//--- branch the algorithm by compute device
   if(CheckPointer(m_cOpenCL) == POINTER_INVALID)
     {
      MATRIX values, gradients;
      if(attention_grad.GetData(gradients, false) < (int)total)
         return false;
      if(!gradients.Reshape(m_iUnits, m_iWindow))
         return false;
      //--- gradient for Values: Scores^T * dOut
      m_cValues.GetGradients().m_mMatrix = m_cScores.m_mMatrix.Transpose().MatMul(gradients);
      //--- gradient for the Scores through the Softmax derivative
      values = m_cValues.GetOutputs().m_mMatrix;
      if(!values.Reshape(m_iUnits, m_iWindow))
         return false;
      gradients = gradients.MatMul(values.Transpose());
      for(int r = 0; r < m_iUnits; r++)
        {
         MATRIX e;
         if(!e.Init(m_iUnits, m_iUnits))
            return false;
         e.Identity();
         for(int s = 0; s < m_iUnits; s++)
            if(!e.Row(e.Row(s) - m_cScores.m_mMatrix.Row(r), s))
               return false;
         VECTOR g = (m_cScores.m_mMatrix.Row(r) * gradients.Row(r)).MatMul(e);
         if(!gradients.Row(g / sqrt(m_iKeysSize), r))
            return false;
        }
      //--- FIX: Keys/Querys outputs are [m_iUnits x m_iKeysSize]
      //--- (their window_out is m_iKeysSize, see Init); reshaping them
      //--- to m_iWindow fails whenever m_iWindow != m_iKeysSize
      values = m_cKeys.GetOutputs().m_mMatrix;
      if(!values.Reshape(m_iUnits, m_iKeysSize))
         return false;
      m_cQuerys.GetGradients().m_mMatrix = gradients.MatMul(values);
      values = m_cQuerys.GetOutputs().m_mMatrix;
      if(!values.Reshape(m_iUnits, m_iKeysSize))
         return false;
      m_cKeys.GetGradients().m_mMatrix = gradients.Transpose().MatMul(values);
      if(!m_cQuerys.GetGradients().m_mMatrix.Reshape(1, m_cQuerys.GetGradients().Total()) ||
         !m_cKeys.GetGradients().m_mMatrix.Reshape(1, m_cKeys.GetGradients().Total()))
         return false;
     }
   else // OpenCL branch
     {
      //--- FIX: guard against unallocated score-gradient buffers before
      //--- calling methods on them (invalid-pointer call otherwise)
      if(CheckPointer(m_cScoreGrad) == POINTER_INVALID ||
         CheckPointer(m_cScoreTemp) == POINTER_INVALID)
         return false;
      //--- make sure all GPU data buffers exist
      if(m_cValues.GetOutputs().GetIndex() < 0)
         return false;
      if(m_cValues.GetGradients().GetIndex() < 0)
         return false;
      if(m_cScores.GetIndex() < 0)
         return false;
      if(m_cAttentionOut.GetGradients().GetIndex() < 0)
         return false;
      if(m_cScoreGrad.GetIndex() < 0)
         return false;
      if(m_cScoreTemp.GetIndex() < 0)
         return false;
      //--- pass the kernel arguments
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_outputs_grad, m_cAttentionOut.GetGradients().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_scores, m_cScores.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_scores_grad, m_cScoreGrad.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_scores_temp, m_cScoreTemp.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_values, m_cValues.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_values_grad, m_cValues.GetGradients().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_AttentionScoreGradients, def_attscr_window, m_iWindow))
         return false;
      //--- enqueue the score-gradient kernel
      int off_set[] = {0, 0};
      int NDRange[] = {m_iUnits, 1};
      if(!m_cOpenCL.Execute(def_k_AttentionScoreGradients, 2, off_set, NDRange))
         return false;
      //--- read back the Values gradients
      if(!m_cValues.GetGradients().BufferRead())
         return false;
      //--- make sure the Q/K GPU buffers exist
      if(m_cQuerys.GetOutputs().GetIndex() < 0)
         return false;
      if(m_cQuerys.GetGradients().GetIndex() < 0)
         return false;
      if(m_cKeys.GetOutputs().GetIndex() < 0)
         return false;
      if(m_cKeys.GetGradients().GetIndex() < 0)
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_keys, m_cKeys.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_keys_grad, m_cKeys.GetGradients().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_querys, m_cQuerys.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_querys_grad, m_cQuerys.GetGradients().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_scores_grad, m_cScoreGrad.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_AttentionHiddenGradients, def_atthgr_key_size, m_iKeysSize))
         return false;
      if(!m_cOpenCL.Execute(def_k_AttentionHiddenGradients, 2, off_set, NDRange))
         return false;
      //--- read back the Querys gradients
      if(!m_cQuerys.GetGradients().BufferRead())
         return false;
     }
//--- transfer the error gradient to the previous layer, accumulating
//--- the contributions of Values, Querys and Keys
   if(!m_cValues.CalcHiddenGradient(prevLayer))
      return false;
   if(!attention_grad.SumArray(prevLayer.GetGradients()))
      return false;
   if(!m_cQuerys.CalcHiddenGradient(prevLayer))
      return false;
   if(!attention_grad.SumArray(prevLayer.GetGradients()))
      return false;
   if(!m_cKeys.CalcHiddenGradient(prevLayer))
      return false;
   if(!prevLayer.GetGradients().SumArray(attention_grad))
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Method of propagating the gradient to the weight matrices       |
//+------------------------------------------------------------------+
//--- Propagates the error gradient to the weight matrices of all
//--- trainable inner layers.
bool CNeuronAttention::CalcDeltaWeights(CNeuronBase *prevLayer)
  {
//--- FIX: validate every object dereferenced below up front — the
//--- original checked m_cFF2/m_cQuerys/m_cKeys/m_cValues but also used
//--- m_cFF1 and m_cAttentionOut without a check
   if(CheckPointer(m_cFF2) == POINTER_INVALID ||
      CheckPointer(m_cFF1) == POINTER_INVALID ||
      CheckPointer(m_cAttentionOut) == POINTER_INVALID ||
      CheckPointer(m_cQuerys) == POINTER_INVALID ||
      CheckPointer(m_cKeys) == POINTER_INVALID ||
      CheckPointer(m_cValues) == POINTER_INVALID)
      return false;
//--- Feed Forward block weights
   if(!m_cFF2.CalcDeltaWeights(m_cFF1))
      return false;
   if(!m_cFF1.CalcDeltaWeights(m_cAttentionOut))
      return false;
//--- projection weights (all take the previous layer as input)
   if(!m_cQuerys.CalcDeltaWeights(prevLayer))
      return false;
   if(!m_cKeys.CalcDeltaWeights(prevLayer))
      return false;
   if(!m_cValues.CalcDeltaWeights(prevLayer))
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Method of updating the weight matrices                           |
//+------------------------------------------------------------------+
//--- Updates the weight matrices of every trainable inner layer.
//--- The update order matches the original: Querys, Keys, Values, FF1, FF2.
bool CNeuronAttention::UpdateWeights(int batch_size, double learningRate, double &Beta[], double &Lambda[])
  {
   CNeuronConv *trainable[5];
   trainable[0] = m_cQuerys;
   trainable[1] = m_cKeys;
   trainable[2] = m_cValues;
   trainable[3] = m_cFF1;
   trainable[4] = m_cFF2;
   for(int i = 0; i < 5; i++)
     {
      if(CheckPointer(trainable[i]) == POINTER_INVALID)
         return false;
      if(!trainable[i].UpdateWeights(batch_size, learningRate, Beta, Lambda))
         return false;
     }
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Method of saving the class contents to a file                    |
//+------------------------------------------------------------------+
//--- Writes the layer state to a file: parent data first, then every
//--- inner layer (fixed order that must match Load), then the geometry
//--- constants. m_cScores is not saved — Load recreates it from m_iUnits.
bool CNeuronAttention::Save(const int file_handle)
  {
   if(!CNeuronBase::Save(file_handle))
      return false;
//--- save the inner layers in the order expected by Load
   CNeuronBase *inner[6];
   inner[0] = m_cQuerys;
   inner[1] = m_cKeys;
   inner[2] = m_cValues;
   inner[3] = m_cAttentionOut;
   inner[4] = m_cFF1;
   inner[5] = m_cFF2;
   for(int i = 0; i < 6; i++)
     {
      if(CheckPointer(inner[i]) == POINTER_INVALID)
         return false;
      if(!inner[i].Save(file_handle))
         return false;
     }
//--- persist the geometry constants
   if(FileWriteInteger(file_handle, m_iUnits) <= 0)
      return false;
   if(FileWriteInteger(file_handle, m_iWindow) <= 0)
      return false;
   if(FileWriteInteger(file_handle, m_iKeysSize) <= 0)
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Method of restoring the class from a file                        |
//+------------------------------------------------------------------+
bool CNeuronAttention::Load(const int file_handle)
{
//--- restore the parent class data first (mirrors Save)
if(!CNeuronBase::Load(file_handle))
return false;
//--- restore Querys: each inner layer is preceded in the file by its
//--- type constant, which is verified before loading
if(CheckPointer(m_cQuerys) == POINTER_INVALID)
{
m_cQuerys = new CNeuronConv();
if(CheckPointer(m_cQuerys) == POINTER_INVALID)
return false;
}
if(FileReadInteger(file_handle) != defNeuronConv || !m_cQuerys.Load(file_handle))
return false;
//--- restore Keys
if(CheckPointer(m_cKeys) == POINTER_INVALID)
{
m_cKeys = new CNeuronConv();
if(CheckPointer(m_cKeys) == POINTER_INVALID)
return false;
}
if(FileReadInteger(file_handle) != defNeuronConv || !m_cKeys.Load(file_handle))
return false;
//--- restore Values
if(CheckPointer(m_cValues) == POINTER_INVALID)
{
m_cValues = new CNeuronConv();
if(CheckPointer(m_cValues) == POINTER_INVALID)
return false;
}
if(FileReadInteger(file_handle) != defNeuronConv || !m_cValues.Load(file_handle))
return false;
//--- restore AttentionOut (a plain base layer, not a convolution)
if(CheckPointer(m_cAttentionOut) == POINTER_INVALID)
{
m_cAttentionOut = new CNeuronBase();
if(CheckPointer(m_cAttentionOut) == POINTER_INVALID)
return false;
}
if(FileReadInteger(file_handle) != defNeuronBase || !m_cAttentionOut.Load(file_handle))
return false;
//--- restore FF1
if(CheckPointer(m_cFF1) == POINTER_INVALID)
{
m_cFF1 = new CNeuronConv();
if(CheckPointer(m_cFF1) == POINTER_INVALID)
return false;
}
if(FileReadInteger(file_handle) != defNeuronConv || !m_cFF1.Load(file_handle))
return false;
//--- restore FF2
if(CheckPointer(m_cFF2) == POINTER_INVALID)
{
m_cFF2 = new CNeuronConv();
if(CheckPointer(m_cFF2) == POINTER_INVALID)
return false;
}
if(FileReadInteger(file_handle) != defNeuronConv || !m_cFF2.Load(file_handle))
return false;
//--- read the geometry constants in the order written by Save:
//--- units, window, key size
m_iUnits = FileReadInteger(file_handle);
int scores = m_iUnits * m_iUnits;
m_iWindow = FileReadInteger(file_handle);
m_iKeysSize = FileReadInteger(file_handle);
//--- recreate the Scores buffer (not persisted; sized units x units)
if(CheckPointer(m_cScores) == POINTER_INVALID)
{
m_cScores = new CBufferDouble();
if(CheckPointer(m_cScores) == POINTER_INVALID)
return false;
}
if(!m_cScores.BufferInit(scores, 0))
return false;
//--- re-link this layer's output buffer to FF2's output buffer
//--- (same pointer substitution as in Init, avoids copying)
if(m_cFF2.GetOutputs() != m_cOutputs)
{
if(CheckPointer(m_cOutputs) != POINTER_INVALID)
delete m_cOutputs;
m_cOutputs = m_cFF2.GetOutputs();
}
//--- re-link this layer's gradient buffer to FF2's gradient buffer
if(m_cFF2.GetGradients() != m_cGradients)
{
if(CheckPointer(m_cGradients) != POINTER_INVALID)
delete m_cGradients;
m_cGradients = m_cFF2.GetGradients();
}
//---
return true;
}
//+------------------------------------------------------------------+