Original_NNB/MQL5/Include/NeuroNetworksBook/realization/neuronmhattention.mqh

803 lines
61 KiB
MQL5
Raw Permalink Normal View History

2025-05-30 16:15:14 +02:00
//+------------------------------------------------------------------+
//| NeuronMHAttention.mqh |
//| Copyright 2021, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2021, MetaQuotes Ltd."
#property link "https://www.mql5.com"
//+------------------------------------------------------------------+
//| Include libraries                                                |
//+------------------------------------------------------------------+
#include "neuronattention.mqh"
//+------------------------------------------------------------------+
//| Class CNeuronMHAttention                                         |
//| Purpose: multi-head self-attention block. Extends               |
//| CNeuronAttention with m_iHeads parallel attention heads and a    |
//| W0 projection that maps the concatenated head outputs back to    |
//| m_iWindow values per sequence element.                           |
//+------------------------------------------------------------------+
class CNeuronMHAttention : public CNeuronAttention
{
protected:
//--- W0 projection layer: a convolution whose window is
//--- m_iKeysSize * m_iHeads and whose output is m_iWindow (see Init)
CNeuronConv *m_cW0;
//--- Number of attention heads (description.step in Init, 8 by default)
int m_iHeads;
public:
CNeuronMHAttention(void);
~CNeuronMHAttention(void);
//---
virtual bool Init(CLayerDescription *description);
virtual bool SetOpenCL(CMyOpenCL *opencl);
virtual bool FeedForward(CNeuronBase *prevLayer);
virtual bool CalcHiddenGradient(CNeuronBase *prevLayer);
virtual bool CalcDeltaWeights(CNeuronBase *prevLayer);
virtual bool UpdateWeights(int batch_size, double learningRate,
double &Beta[], double &Lambda[]);
//--- File operations
virtual bool Save(const int file_handle);
virtual bool Load(const int file_handle);
//--- Object type identification
virtual int Type(void) const { return(defNeuronMHAttention); }
};
//+------------------------------------------------------------------+
//| Class constructor                                                |
//| Defaults to 8 heads and pre-allocates the W0 projection layer;   |
//| Init() later overrides the head count and configures W0.         |
//+------------------------------------------------------------------+
CNeuronMHAttention::CNeuronMHAttention(void)
  {
   m_iHeads = 8;
   m_cW0 = new CNeuronConv();
  }
//+------------------------------------------------------------------+
//| Class destructor                                                 |
//| Releases the W0 projection layer owned by this class; the other  |
//| sub-layers are handled by the parent class.                      |
//+------------------------------------------------------------------+
CNeuronMHAttention::~CNeuronMHAttention(void)
  {
   if(CheckPointer(m_cW0) == POINTER_INVALID)
      return;
   delete m_cW0;
  }
//+------------------------------------------------------------------+
//| Class initialization method                                      |
//| Builds the Querys/Keys/Values projections, the Scores buffer,    |
//| the AttentionOut buffer, the W0 projection and the FF1/FF2       |
//| feed-forward pair from the layer description:                    |
//|   description.window     -> m_iWindow  (values per element)      |
//|   description.count      -> m_iUnits   (sequence elements)       |
//|   description.window_out -> m_iKeysSize (size per head)          |
//|   description.step       -> m_iHeads   (number of heads)         |
//| NOTE: description is modified in place. On success it is left    |
//| with count = m_iUnits * m_iWindow, window = 0, window_out = 1    |
//| and type = defNeuronBase. Presumably callers read count back as  |
//| the layer's output size; confirm before relying on this.         |
//| Returns false on invalid arguments or any sub-object failure.    |
//+------------------------------------------------------------------+
bool CNeuronMHAttention::Init(CLayerDescription *description)
{
//--- Validate the input description
if(CheckPointer(description) == POINTER_INVALID || description.type != Type() ||
description.count <= 0 || description.window <= 0 || description.window_out <= 0 ||
description.step <= 0)
return false;
//--- Store the constants
m_iWindow = description.window;
m_iUnits = description.count;
m_iKeysSize = description.window_out;
m_iHeads = description.step;
//--- Create a description for the internal neural layers
//--- (first used for Querys/Keys/Values: m_iWindow inputs per element ->
//--- m_iKeysSize * m_iHeads outputs per element, no activation)
CLayerDescription *temp = new CLayerDescription();
if(CheckPointer(temp) == POINTER_INVALID)
return false;
temp.type = defNeuronConv;
temp.window = m_iWindow;
temp.window_out = (int)(m_iKeysSize * m_iHeads);
temp.step = m_iWindow;
temp.count = m_iUnits;
temp.activation = ACT_None;
temp.activation_params[0] = 1;
temp.activation_params[1] = 0;
temp.optimization = description.optimization;
//--- Call the initialization method of CNeuronBase directly (not of
//--- CNeuronAttention), sized for m_iUnits * m_iWindow values;
//--- window = 0 presumably means "no weight matrix" - confirm in CNeuronBase
description.count *= m_iWindow;
description.window_out = 1;
description.window = 0;
if(!CNeuronBase::Init(description))
{
delete temp;
return false;
}
//--- Initialize Querys
if(CheckPointer(m_cQuerys) == POINTER_INVALID)
{
m_cQuerys = new CNeuronConv();
if(CheckPointer(m_cQuerys) == POINTER_INVALID)
{
delete temp;
return false;
}
}
if(!m_cQuerys.Init(temp))
{
delete temp;
return false;
}
m_cQuerys.SetTransposedOutput(true);
//--- Initialize Keys
if(CheckPointer(m_cKeys) == POINTER_INVALID)
{
m_cKeys = new CNeuronConv();
if(CheckPointer(m_cKeys) == POINTER_INVALID)
{
delete temp;
return false;
}
}
if(!m_cKeys.Init(temp))
{
delete temp;
return false;
}
m_cKeys.SetTransposedOutput(true);
//--- Initialize Values
if(CheckPointer(m_cValues) == POINTER_INVALID)
{
m_cValues = new CNeuronConv();
if(CheckPointer(m_cValues) == POINTER_INVALID)
{
delete temp;
return false;
}
}
if(!m_cValues.Init(temp))
{
delete temp;
return false;
}
m_cValues.SetTransposedOutput(true);
//--- Initialize Scores: one m_iUnits x m_iUnits matrix per head
if(CheckPointer(m_cScores) == POINTER_INVALID)
{
m_cScores = new CBufferDouble();
if(CheckPointer(m_cScores) == POINTER_INVALID)
{
delete temp;
return false;
}
}
if(!m_cScores.BufferInit(m_iUnits * m_iUnits * m_iHeads, 0))
{
delete temp;
return false;
}
//--- Initialize AttentionOut: concatenated head outputs
//--- (m_iUnits * m_iKeysSize * m_iHeads values)
if(CheckPointer(m_cAttentionOut) == POINTER_INVALID)
{
m_cAttentionOut = new CNeuronBase();
if(CheckPointer(m_cAttentionOut) == POINTER_INVALID)
{
delete temp;
return false;
}
}
description.type = defNeuronBase;
description.count = (int)(m_iUnits * m_iKeysSize * m_iHeads);
if(!m_cAttentionOut.Init(description))
{
delete temp;
return false;
}
//--- Put count back to this layer's output size (type is not restored)
description.count = m_iUnits * m_iWindow;
//--- Initialize W0: m_iKeysSize * m_iHeads -> m_iWindow per element
if(CheckPointer(m_cW0) == POINTER_INVALID)
{
m_cW0 = new CNeuronConv();
if(CheckPointer(m_cW0) == POINTER_INVALID)
{
delete temp;
return false;
}
}
temp.window = (int)(m_iKeysSize * m_iHeads);
temp.step = temp.window;
temp.window_out = m_iWindow;
temp.activation = ACT_None;
temp.activation_params[0] = 1;
temp.activation_params[1] = 0;
if(!m_cW0.Init(temp))
{
delete temp;
return false;
}
m_cW0.SetTransposedOutput(true);
//--- Initialize FF1: m_iWindow -> 4 * m_iWindow per element, SWISH activation
if(CheckPointer(m_cFF1) == POINTER_INVALID)
{
m_cFF1 = new CNeuronConv();
if(CheckPointer(m_cFF1) == POINTER_INVALID)
{
delete temp;
return false;
}
}
temp.window = m_iWindow;
temp.step = temp.window;
temp.window_out = temp.window * 4;
temp.activation = ACT_SWISH;
temp.activation_params[0] = 1;
temp.activation_params[1] = 0;
if(!m_cFF1.Init(temp))
{
delete temp;
return false;
}
m_cFF1.SetTransposedOutput(true);
//--- Initialize FF2: 4 * m_iWindow -> m_iWindow per element, no activation.
//--- Order matters: temp.step still holds m_iWindow from the FF1 setup.
if(CheckPointer(m_cFF2) == POINTER_INVALID)
{
m_cFF2 = new CNeuronConv();
if(CheckPointer(m_cFF2) == POINTER_INVALID)
{
delete temp;
return false;
}
}
temp.window = temp.window_out;
temp.window_out = temp.step;
temp.step = temp.window;
temp.activation = ACT_None;
temp.activation_params[0] = 1;
temp.activation_params[1] = 0;
if(!m_cFF2.Init(temp))
{
delete temp;
return false;
}
m_cFF2.SetTransposedOutput(true);
delete temp;
//--- To avoid copying buffers, replace this layer's output and gradient
//--- buffers with FF2's buffers (the block output is FF2's output)
if(CheckPointer(m_cOutputs) != POINTER_INVALID)
delete m_cOutputs;
m_cOutputs = m_cFF2.GetOutputs();
if(CheckPointer(m_cGradients) != POINTER_INVALID)
delete m_cGradients;
m_cGradients = m_cFF2.GetGradients();
//--- Propagate the current OpenCL context to all sub-objects
SetOpenCL(m_cOpenCL);
//---
return true;
}
//+------------------------------------------------------------------+
//| Passes the OpenCL context pointer to all internal objects        |
//| The parent class stores the context and forwards it to its own   |
//| sub-layers; W0 exists only in this class and is handled here.    |
//| Returns true when a valid OpenCL context is attached afterwards. |
//+------------------------------------------------------------------+
bool CNeuronMHAttention::SetOpenCL(CMyOpenCL *opencl)
  {
   CNeuronAttention::SetOpenCL(opencl);
//--- Forward the (possibly rejected/cleared) context stored by the parent
   bool w0_exists = (CheckPointer(m_cW0) != POINTER_INVALID);
   if(w0_exists)
      m_cW0.SetOpenCL(m_cOpenCL);
//---
   bool has_context = (CheckPointer(m_cOpenCL) != POINTER_INVALID);
   return has_context;
  }
//+------------------------------------------------------------------+
//| Feed-forward pass                                                |
//| 1. Projects prevLayer's outputs into Querys, Keys and Values.    |
//| 2. For every head computes Scores = softmax(Q*K^T/sqrt(dk)) and  |
//|    the head output Scores*V into AttentionOut (CPU or OpenCL).   |
//| 3. Projects the concatenated heads through W0, adds prevLayer's  |
//|    outputs (residual) and normalizes with mean/std (m_dStd[0]).  |
//| 4. Runs FF1 -> FF2, adds the step-3 result (residual) and        |
//|    normalizes again (m_dStd[1]) into m_cOutputs.                 |
//| A zero standard deviation is treated as 1, so constant data       |
//| yields zeros instead of NaN. CalcHiddenGradient likewise skips   |
//| its rescale when m_dStd is 0.                                    |
//| Returns false on any invalid object, size mismatch or failure.   |
//+------------------------------------------------------------------+
bool CNeuronMHAttention::FeedForward(CNeuronBase *prevLayer)
  {
//--- Check that all objects are valid
   if(CheckPointer(prevLayer) == POINTER_INVALID ||
      CheckPointer(prevLayer.GetOutputs()) == POINTER_INVALID ||
      CheckPointer(m_cQuerys) == POINTER_INVALID ||
      CheckPointer(m_cValues) == POINTER_INVALID ||
      CheckPointer(m_cKeys) == POINTER_INVALID ||
      CheckPointer(m_cW0) == POINTER_INVALID ||
      CheckPointer(m_cFF1) == POINTER_INVALID ||
      CheckPointer(m_cFF2) == POINTER_INVALID)
      return false;
//--- Project the input into Querys, Keys and Values
   if(!m_cQuerys.FeedForward(prevLayer))
      return false;
   if(!m_cKeys.FeedForward(prevLayer))
      return false;
   if(!m_cValues.FeedForward(prevLayer))
      return false;
//--- Make sure the Scores buffer exists
   if(CheckPointer(m_cScores) == POINTER_INVALID)
     {
      m_cScores = new CBufferDouble();
      if(CheckPointer(m_cScores) == POINTER_INVALID)
         return false;
     }
//--- Make sure AttentionOut exists and matches the size of Values
   int total = m_cValues.GetOutputs().Total();
   if(CheckPointer(m_cAttentionOut) == POINTER_INVALID)
     {
      m_cAttentionOut = new CNeuronBase();
      if(CheckPointer(m_cAttentionOut) == POINTER_INVALID)
         return false;
     }
   if(CheckPointer(m_cAttentionOut.GetOutputs()) == POINTER_INVALID ||
      m_cAttentionOut.GetOutputs().Total() != total)
     {
      CLayerDescription *temp = new CLayerDescription();
      if(CheckPointer(temp) == POINTER_INVALID)
         return false;
      temp.type = defNeuronBase;
      temp.count = total;
      temp.window = 0;
      if(!m_cAttentionOut.Init(temp))
        {
         delete temp;
         return false;
        }
      delete temp;
     }
//--- Branch by compute device
   double summs[];
   if(CheckPointer(m_cOpenCL) == POINTER_INVALID)
     {
      CBufferDouble *querys = m_cQuerys.GetOutputs();
      CBufferDouble *keys = m_cKeys.GetOutputs();
      //--- Compute Scores: exp(q.k / sqrt(m_iKeysSize)) per head, query row and key
      double scores[];
      if(ArrayResize(scores, m_iUnits * m_iUnits * m_iHeads) <= 0 ||
         ArrayResize(summs, m_iUnits) <= 0)
         return false;
      for(int head = 0; head < m_iHeads; head++)
        {
         for(int query = 0; query < m_iUnits; query++)
           {
            int shift_query = query * m_iKeysSize * m_iHeads + head * m_iKeysSize;
            summs[query] = 0;
            for(int key = 0; key < m_iUnits; key++)
              {
               int shift_key = key * m_iKeysSize * m_iHeads + head * m_iKeysSize;
               int shift_score = query * m_iUnits * m_iHeads + head * m_iUnits + key;
               double score = 0;
               for(int i = 0; i < m_iKeysSize; i++)
                  score += querys.At(shift_query + i) * keys.At(shift_key + i);
               scores[shift_score] = MathExp(score / MathSqrt(m_iKeysSize));
               summs[query] += scores[shift_score];
              }
           }
         //--- Normalize Scores (softmax over the keys of each query row)
         for(int query = 0; query < m_iUnits; query++)
           {
            for(int key = 0; key < m_iUnits; key++)
               scores[query * m_iUnits * m_iHeads + head * m_iUnits + key] /= summs[query];
           }
        }
      if(!m_cScores.AssignArray(scores))
         return false;
      //--- Attention block output: weighted sum of Values per head
      if(ArrayResize(summs, total) < total)
         return false;
      if(ArrayInitialize(summs, 0) < total)
         return false;
      CBufferDouble *values = m_cValues.GetOutputs();
      for(int head = 0; head < m_iHeads; head++)
        {
         for(int value = 0; value < m_iUnits; value++)
           {
            int shift_value = m_iKeysSize * (value * m_iHeads + head);
            for(int pos = 0; pos < m_iKeysSize; pos++)
              {
               double val = values.At(shift_value + pos);
               for(int query = 0; query < m_iUnits; query++)
                  summs[m_iKeysSize * (query * m_iHeads + head) + pos] += val * scores[m_iUnits * (query * m_iHeads + head) + value];
              }
           }
        }
      if(!m_cAttentionOut.GetOutputs().AssignArray(summs))
         return false;
     }
   else // OpenCL block
     {
      //--- Create data buffers
      if(m_cQuerys.GetOutputs().GetIndex() < 0 && !m_cQuerys.GetOutputs().BufferCreate(m_cOpenCL))
         return false;
      if(m_cKeys.GetOutputs().GetIndex() < 0 && !m_cKeys.GetOutputs().BufferCreate(m_cOpenCL))
         return false;
      if(m_cValues.GetOutputs().GetIndex() < 0 && !m_cValues.GetOutputs().BufferCreate(m_cOpenCL))
         return false;
      int scores = m_iUnits * m_iUnits * m_iHeads;
      if(m_cScores.Total() != scores)
         if(!m_cScores.BufferInit(scores, 0))
            return false;
      if(m_cScores.GetIndex() < 0 && !m_cScores.BufferCreate(m_cOpenCL))
         return false;
      if(m_cAttentionOut.GetOutputs().GetIndex() < 0 && !m_cAttentionOut.GetOutputs().BufferCreate(m_cOpenCL))
         return false;
      //--- Pass parameters to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_keys, m_cKeys.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_outputs, m_cAttentionOut.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_querys, m_cQuerys.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_scores, m_cScores.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionFeedForward, def_attff_values, m_cValues.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_AttentionFeedForward, def_attff_key_size, m_iKeysSize))
         return false;
      //--- Each head's value vector has m_iKeysSize elements
      if(!m_cOpenCL.SetArgument(def_k_AttentionFeedForward, def_attff_window, m_iKeysSize))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_AttentionFeedForward, def_attff_mask, 0))
         return false;
      //--- Enqueue the kernel: one work item per (query, head) pair
      int off_set[] = {0,0};
      int NDRange[] = {m_iUnits,m_iHeads};
      if(!m_cOpenCL.Execute(def_k_AttentionFeedForward, 2, off_set, NDRange))
         return false;
      //--- Read back the results of the operations
      if(!m_cAttentionOut.GetOutputs().BufferRead())
         return false;
      if(!m_cScores.BufferRead())
         return false;
      m_cQuerys.GetOutputs().BufferFree();
      m_cKeys.GetOutputs().BufferFree();
      m_cValues.GetOutputs().BufferFree();
      m_cScores.BufferFree();
      prevLayer.GetOutputs().BufferFree();
     }
//--- Project the concatenated heads back to m_iWindow values per element
   if(!m_cW0.FeedForward(m_cAttentionOut))
      return false;
   total = m_cW0.GetOutputs().GetData(summs, false);
   if(total <= 0)
      return false;
   if(ArrayResize(summs, total) != total)
      return false;
//--- Add the source data (residual connection) and normalize.
//--- The residual needs prevLayer to provide at least `total` values;
//--- otherwise At() would read past the end of the buffer.
   CBufferDouble *prev = prevLayer.GetOutputs();
   if(prev.Total() < total)
      return false;
   double mean = 0;
   for(int i = 0; i < total; i++)
     {
      summs[i] += prev.At(i);
      mean += summs[i];
     }
   mean /= total;
   m_dStd[0] = MathStandardDeviation(summs);
//--- A zero deviation (constant data) would produce 0/0 = NaN; divide by 1
//--- instead. m_dStd[0] keeps the raw value for the backward pass.
   double std0 = m_dStd[0];
   if(std0 == 0)
      std0 = 1;
   for(int i = 0; i < total; i++)
      summs[i] = (summs[i] - mean) / std0;
   if(!m_cW0.GetOutputs().AssignArray(summs))
      return false;
//--- Position-wise feed-forward block
   if(!m_cFF1.FeedForward(m_cW0))
      return false;
   if(!m_cFF2.FeedForward(m_cFF1))
      return false;
//--- Add the normalized attention output (residual connection) and normalize
   if(CheckPointer(m_cOutputs) == POINTER_INVALID || m_cOutputs.Total() < total)
      return false;
   mean = 0;
   for(int i = 0; i < total; i++)
     {
      summs[i] += m_cOutputs.At(i);
      mean += summs[i];
     }
   mean /= total;
   m_dStd[1] = MathStandardDeviation(summs);
   double std1 = m_dStd[1];
   if(std1 == 0)
      std1 = 1;
   for(int i = 0; i < total; i++)
      summs[i] = (summs[i] - mean) / std1;
   if(!m_cOutputs.AssignArray(summs))
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Propagates the error gradient through the hidden layer           |
//| Walks the forward pass in reverse:                               |
//|   m_cGradients (/ m_dStd[1]) -> FF2 -> FF1 -> W0 gradients,      |
//|   plus the second residual (/ m_dStd[0]) -> AttentionOut ->      |
//|   Values, Querys and Keys (CPU or OpenCL) -> prevLayer, where    |
//|   the first residual gradient (a copy of W0's gradient) is added.|
//| Returns false on any invalid object or sub-layer failure.        |
//+------------------------------------------------------------------+
bool CNeuronMHAttention::CalcHiddenGradient(CNeuronBase *prevLayer)
  {
//--- Check every object that is dereferenced below
   if(CheckPointer(prevLayer) == POINTER_INVALID ||
      CheckPointer(m_cOutputs) == POINTER_INVALID ||
      CheckPointer(m_cGradients) == POINTER_INVALID ||
      CheckPointer(m_cScores) == POINTER_INVALID ||
      CheckPointer(m_cFF2) == POINTER_INVALID ||
      CheckPointer(m_cFF1) == POINTER_INVALID ||
      CheckPointer(m_cW0) == POINTER_INVALID ||
      CheckPointer(m_cAttentionOut) == POINTER_INVALID ||
      CheckPointer(m_cQuerys) == POINTER_INVALID ||
      CheckPointer(m_cKeys) == POINTER_INVALID ||
      CheckPointer(m_cValues) == POINTER_INVALID ||
      m_cOutputs.Total() != m_cGradients.Total())
      return false;
//--- Scale the error gradient by the second normalization's deviation
//--- (skipped when it is 0, matching the forward pass)
   if(m_dStd[1] != 0 && m_cGradients.Scaling(1 / m_dStd[1]) <= 0)
      return false;
//--- Propagate the error gradient through the Feed Forward block
   if(!m_cFF2.CalcHiddenGradient(m_cFF1))
      return false;
   if(!m_cFF1.CalcHiddenGradient(m_cW0))
      return false;
//--- Add the second residual gradient, then undo the first normalization
   CBufferDouble *attention_grad = m_cW0.GetGradients();
   if(!attention_grad.SumArray(m_cGradients))
      return false;
   if(m_dStd[0] != 0 && attention_grad.Scaling(1 / m_dStd[0]) <= 0)
      return false;
//--- Distribute the error gradient across the attention heads
   if(!m_cW0.CalcHiddenGradient(m_cAttentionOut))
      return false;
//--- Branch by compute device
   attention_grad = m_cAttentionOut.GetGradients();
   if(CheckPointer(m_cOpenCL) == POINTER_INVALID)
     {
      int total = m_cValues.GetOutputs().Total();
      double values[];
      double gradients[];
      if(ArrayResize(values, total) < total || attention_grad.GetData(gradients, false) < total)
         return false;
      if(ArrayInitialize(values, 0) < total)
         return false;
      //--- Distribute the gradient to Values: dV = Scores^T * dOut per head
      for(int head = 0; head < m_iHeads; head++)
        {
         for(int value = 0; value < m_iUnits; value++)
           {
            for(int grad = 0; grad < m_iUnits; grad++)
              {
               double score = m_cScores.At(m_iUnits * (grad * m_iHeads + head) + value);
               for(int i = 0; i < m_iKeysSize; i++)
                  values[m_iKeysSize * (value * m_iHeads + head) + i] += gradients[m_iKeysSize * (grad * m_iHeads + head) + i] * score;
              }
           }
        }
      if(!m_cValues.GetGradients().AssignArray(values))
         return false;
      //--- Distribute the gradient to Querys and Keys
      //--- (the `values` array is reused to hold the Values outputs)
      if(m_cValues.GetOutputs().GetData(values, false) <= 0)
         return false;
      double querys[], querys_grad[];
      double keys[], keys_grad[];
      int keys_total = m_iUnits * m_iKeysSize * m_iHeads;
      if(m_cQuerys.GetOutputs().GetData(querys, false) < keys_total)
         return false;
      if(m_cKeys.GetOutputs().GetData(keys, false) < keys_total)
         return false;
      if(ArrayResize(querys_grad, keys_total) <= 0 || ArrayResize(keys_grad, keys_total) <= 0)
         return false;
      if(ArrayInitialize(querys_grad, 0) <= 0 || ArrayInitialize(keys_grad, 0) <= 0)
         return false;
      double score_grad[];
      if(ArrayResize(score_grad, m_iUnits) <= 0)
         return false;
      for(int head = 0; head < m_iHeads; head++)
        {
         for(int q = 0; q < m_iUnits; q++)
           {
            //--- score_grad[k] = dL/dScore(q,k) = dOut(q) . V(k)
            if(ArrayInitialize(score_grad, 0) <= 0)
               return false;
            for(int k = 0; k < m_iUnits; k++)
              {
               for(int i = 0; i < m_iKeysSize; i++)
                  score_grad[k] += gradients[m_iKeysSize * (q * m_iHeads + head) + i] * values[m_iKeysSize * (k * m_iHeads + head) + i];
              }
            //--- Back through softmax: sum_i s_i * (delta_ik - s_k) * score_grad[i],
            //--- then through the 1/sqrt(m_iKeysSize) scaling into Q and K
            int shift_grad = m_iKeysSize * (q * m_iHeads + head);
            for(int k = 0; k < m_iUnits; k++)
              {
               int shift_key = m_iKeysSize * (k * m_iHeads + head);
               double score = m_cScores.At(m_iUnits * (q * m_iHeads + head) + k);
               double grad = 0;
               for(int i = 0; i < m_iUnits; i++)
                  grad += m_cScores.At(m_iUnits * (q * m_iHeads + head) + i) * ((int)(i == k) - score) * score_grad[i];
               grad /= MathSqrt(m_iKeysSize);
               //---
               for(int i = 0; i < m_iKeysSize; i++)
                 {
                  querys_grad[shift_grad + i] += grad * keys[shift_key + i];
                  keys_grad[shift_key + i ] += grad * querys[shift_grad + i];
                 }
              }
           }
        }
      if(!m_cQuerys.GetGradients().AssignArray(querys_grad) || !m_cKeys.GetGradients().AssignArray(keys_grad))
         return false;
     }
   else // OpenCL block
     {
      //--- Create data buffers
      if(m_cValues.GetOutputs().GetIndex() < 0 && !m_cValues.GetOutputs().BufferCreate(m_cOpenCL))
         return false;
      if(m_cValues.GetGradients().GetIndex() < 0 && !m_cValues.GetGradients().BufferCreate(m_cOpenCL))
         return false;
      if(m_cScores.GetIndex() < 0 && !m_cScores.BufferCreate(m_cOpenCL))
         return false;
      if(m_cAttentionOut.GetGradients().GetIndex() < 0 && !m_cAttentionOut.GetGradients().BufferCreate(m_cOpenCL))
         return false;
      if(CheckPointer(m_cScoreGrad) == POINTER_INVALID)
        {
         m_cScoreGrad = new CBufferDouble();
         if(CheckPointer(m_cScoreGrad) == POINTER_INVALID)
            return false;
        }
      if(m_cScoreGrad.Total() != m_cScores.Total())
         if(!m_cScoreGrad.BufferInit(m_cScores.Total(), 0))
            return false;
      if(m_cScoreGrad.GetIndex() < 0 && !m_cScoreGrad.BufferCreate(m_cOpenCL))
         return false;
      //---
      if(CheckPointer(m_cScoreTemp) == POINTER_INVALID)
        {
         m_cScoreTemp = new CBufferDouble();
         if(CheckPointer(m_cScoreTemp) == POINTER_INVALID)
            return false;
        }
      if(m_cScoreTemp.Total() != m_cScores.Total())
         if(!m_cScoreTemp.BufferInit(m_cScores.Total(), 0))
            return false;
      if(m_cScoreTemp.GetIndex() < 0 && !m_cScoreTemp.BufferCreate(m_cOpenCL))
         return false;
      //--- Pass parameters to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_outputs_grad, m_cAttentionOut.GetGradients().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_scores, m_cScores.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_scores_grad, m_cScoreGrad.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_scores_temp, m_cScoreTemp.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_values, m_cValues.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionScoreGradients, def_attscr_values_grad, m_cValues.GetGradients().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_AttentionScoreGradients, def_attscr_window, m_iKeysSize))
         return false;
      //--- Enqueue the kernel: one work item per (element, head) pair
      int off_set[] = {0,0};
      int NDRange[] = {m_iUnits,m_iHeads};
      if(!m_cOpenCL.Execute(def_k_AttentionScoreGradients, 2, off_set, NDRange))
         return false;
      //--- Read back the results
      if(!m_cValues.GetGradients().BufferRead())
         return false;
      m_cValues.GetOutputs().BufferFree();
      m_cScores.BufferFree();
      m_cScoreTemp.BufferFree();
      m_cAttentionOut.GetOutputs().BufferFree();
      //--- Create data buffers
      if(m_cQuerys.GetOutputs().GetIndex() < 0 && !m_cQuerys.GetOutputs().BufferCreate(m_cOpenCL))
         return false;
      if(m_cQuerys.GetGradients().GetIndex() < 0 && !m_cQuerys.GetGradients().BufferCreate(m_cOpenCL))
         return false;
      if(m_cKeys.GetOutputs().GetIndex() < 0 && !m_cKeys.GetOutputs().BufferCreate(m_cOpenCL))
         return false;
      if(m_cKeys.GetGradients().GetIndex() < 0 && !m_cKeys.GetGradients().BufferCreate(m_cOpenCL))
         return false;
      //--- Pass arguments to the kernel
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_keys, m_cKeys.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_keys_grad, m_cKeys.GetGradients().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_querys, m_cQuerys.GetOutputs().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_querys_grad, m_cQuerys.GetGradients().GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgumentBuffer(def_k_AttentionHiddenGradients, def_atthgr_scores_grad, m_cScoreGrad.GetIndex()))
         return false;
      if(!m_cOpenCL.SetArgument(def_k_AttentionHiddenGradients, def_atthgr_key_size, m_iKeysSize))
         return false;
      //--- Enqueue the kernel (same work-item layout as above)
      if(!m_cOpenCL.Execute(def_k_AttentionHiddenGradients, 2, off_set, NDRange))
         return false;
      //--- Read back the results
      if(!m_cQuerys.GetGradients().BufferRead() ||
         !m_cKeys.GetGradients().BufferRead())
         return false;
      //---
      m_cScoreGrad.BufferFree();
     }
//--- Transfer the error gradient to the previous layer:
//--- keep a copy of W0's gradient (the first residual path) because each
//--- CalcHiddenGradient(prevLayer) call below overwrites prevLayer's gradients
   attention_grad = new CBufferDouble();
   if(CheckPointer(attention_grad) == POINTER_INVALID)
      return false;
   if(!attention_grad.AssignArray(m_cW0.GetGradients()))
     {
      delete attention_grad;
      return false;
     }
   if(!m_cValues.CalcHiddenGradient(prevLayer))
     {
      delete attention_grad;
      return false;
     }
   if(!attention_grad.SumArray(prevLayer.GetGradients()))
     {
      delete attention_grad;
      return false;
     }
   if(!m_cQuerys.CalcHiddenGradient(prevLayer))
     {
      delete attention_grad;
      return false;
     }
   if(!attention_grad.SumArray(prevLayer.GetGradients()))
     {
      delete attention_grad;
      return false;
     }
   if(!m_cKeys.CalcHiddenGradient(prevLayer))
     {
      delete attention_grad;
      return false;
     }
   if(!prevLayer.GetGradients().SumArray(attention_grad))
     {
      delete attention_grad;
      return false;
     }
   delete attention_grad;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Propagates the error gradient to the weight matrices            |
//| Accumulates weight deltas for FF2, FF1, W0 and the               |
//| Querys/Keys/Values projections, each against its own input.      |
//| Every object is validated before any delta is accumulated, so a  |
//| missing sub-layer cannot leave the block partially updated.      |
//+------------------------------------------------------------------+
bool CNeuronMHAttention::CalcDeltaWeights(CNeuronBase *prevLayer)
  {
//--- Validation block
   if(CheckPointer(prevLayer) == POINTER_INVALID ||
      CheckPointer(m_cFF2) == POINTER_INVALID ||
      CheckPointer(m_cFF1) == POINTER_INVALID ||
      CheckPointer(m_cW0) == POINTER_INVALID ||
      CheckPointer(m_cAttentionOut) == POINTER_INVALID ||
      CheckPointer(m_cQuerys) == POINTER_INVALID ||
      CheckPointer(m_cKeys) == POINTER_INVALID ||
      CheckPointer(m_cValues) == POINTER_INVALID)
      return false;
//--- Call the same method for all internal layers
   if(!m_cFF2.CalcDeltaWeights(m_cFF1))
      return false;
   if(!m_cFF1.CalcDeltaWeights(m_cW0))
      return false;
   if(!m_cW0.CalcDeltaWeights(m_cAttentionOut))
      return false;
   if(!m_cQuerys.CalcDeltaWeights(prevLayer))
      return false;
   if(!m_cKeys.CalcDeltaWeights(prevLayer))
      return false;
   if(!m_cValues.CalcDeltaWeights(prevLayer))
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Updates the weight matrices                                      |
//| First updates everything handled by CNeuronAttention, then the   |
//| W0 projection that only this class owns. Stops at the first      |
//| failure and returns false.                                       |
//+------------------------------------------------------------------+
bool CNeuronMHAttention::UpdateWeights(int batch_size, double learningRate, double &Beta[], double &Lambda[])
  {
   bool result = CNeuronAttention::UpdateWeights(batch_size, learningRate, Beta, Lambda);
   if(result)
      result = (CheckPointer(m_cW0) != POINTER_INVALID);
   if(result)
      result = m_cW0.UpdateWeights(batch_size, learningRate, Beta, Lambda);
//---
   return result;
  }
//+------------------------------------------------------------------+
//| Saves the class elements to a file                               |
//| Layout: parent-class data, then m_iHeads (integer), then the     |
//| W0 projection layer. Returns false as soon as any write fails.   |
//+------------------------------------------------------------------+
bool CNeuronMHAttention::Save(const int file_handle)
  {
//--- Parent-class state comes first
   if(!CNeuronAttention::Save(file_handle))
      return false;
//--- Followed by the number of heads
   bool heads_written = (FileWriteInteger(file_handle, m_iHeads) > 0);
   if(!heads_written)
      return false;
//--- And finally the W0 projection layer
   if(CheckPointer(m_cW0) == POINTER_INVALID)
      return false;
   return m_cW0.Save(file_handle);
  }
//+------------------------------------------------------------------+
//| Restores the class from saved data                               |
//| Reads, in order: parent-class data, m_iHeads, then a layer-type  |
//| integer that must equal defNeuronConv, followed by the W0 layer. |
//| Returns false on a read failure, an invalid head count or an     |
//| unexpected W0 layer type.                                        |
//+------------------------------------------------------------------+
bool CNeuronMHAttention::Load(const int file_handle)
  {
//--- Restore the parent-class state
   if(!CNeuronAttention::Load(file_handle))
      return false;
//--- Restore the number of heads. Init() never accepts a head count <= 0,
//--- so such a value here means corrupt or truncated data; reject it rather
//--- than run the attention loops with zero or negative heads.
   m_iHeads = FileReadInteger(file_handle);
   if(m_iHeads <= 0)
      return false;
//--- Restore the W0 projection layer, creating it if necessary
   if(CheckPointer(m_cW0) == POINTER_INVALID)
     {
      m_cW0 = new CNeuronConv();
      if(CheckPointer(m_cW0) == POINTER_INVALID)
         return false;
     }
   if(FileReadInteger(file_handle) != defNeuronConv || !m_cW0.Load(file_handle))
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+