//--- Source: NeuroNetworksBook/Include/realization/neuronlstm.mqh
//--- (web-viewer residue converted to a comment: "861 lines, 61 KiB, MQL5, 2025-05-30 16:12:34 +02:00")
//+------------------------------------------------------------------+
//| NeuronLSTM.mqh |
//| Copyright 2021, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2021, MetaQuotes Ltd."
#property link "https://www.mql5.com"
//+------------------------------------------------------------------+
//| Include libraries                                                |
//+------------------------------------------------------------------+
#include "neuronbase.mqh"
#include <Arrays\ArrayObj.mqh>
//+------------------------------------------------------------------+
//| Class CNeuronLSTM |
//| Purpose: class implementing a recurrent LSTM block               |
//+------------------------------------------------------------------+
class CNeuronLSTM : public CNeuronBase
{
protected:
//--- Internal gate layers; each one is fed the concatenation
//--- [previous layer outputs | hidden state] built in FeedForward
CNeuronBase *m_cForgetGate;      // forget gate, sigmoid activation (set in Init)
CNeuronBase *m_cInputGate;       // input gate, sigmoid activation
CNeuronBase *m_cNewContent;      // candidate ("new content") layer, tanh activation
CNeuronBase *m_cOutputGate;      // output gate, sigmoid activation
//--- History stacks: the newest entry is stored at index 0 and ClearBuffer
//--- trims every stack to m_iDepth + 1 entries
CArrayObj *m_cMemorys;           // memory (cell) states
CArrayObj *m_cHiddenStates;      // hidden states (copied into the layer outputs)
CArrayObj *m_cInputs;            // CNeuronBase objects holding the concatenated input of each step
CArrayObj *m_cForgetGateOuts;    // copies of the forget gate outputs
CArrayObj *m_cInputGateOuts;     // copies of the input gate outputs
CArrayObj *m_cNewContentOuts;    // copies of the new content outputs
CArrayObj *m_cOutputGateOuts;    // copies of the output gate outputs
CBufferDouble *m_cInputGradient; // accumulator summing the gates' input gradients in CalcHiddenGradient
int m_iDepth;                    // history depth; Init sets it to at least 2
void ClearBuffer(CArrayObj *buffer);   // trims a stack to m_iDepth + 1 entries
bool InsertBuffer(CArrayObj *&array, CBufferDouble *element, bool create_new = true); // pushes element (or a copy of it when create_new) to index 0
CBufferDouble *CreateBuffer(CArrayObj *&array); // allocates a new state buffer for a stack, creating the stack if missing
public:
CNeuronLSTM(void);
~CNeuronLSTM(void);
//--- Initialization and OpenCL context propagation
virtual bool Init(CLayerDescription *desc);
virtual bool SetOpenCL(CMyOpenCL *opencl);
//--- Forward and backward passes
virtual bool FeedForward(CNeuronBase *prevLayer);
virtual bool CalcHiddenGradient(CNeuronBase *prevLayer);
//--- No-op: the gates' weight deltas are accumulated per step inside CalcHiddenGradient
virtual bool CalcDeltaWeights(CNeuronBase *prevLayer) { return true; }
virtual bool UpdateWeights(int batch_size, double learningRate,
double &Beta[], double &Lambda[]);
//--- Depth of the stored history
virtual int GetDepth(void) const { return m_iDepth; }
//--- File operations
virtual bool Save(const int file_handle);
virtual bool Load(const int file_handle);
//--- Object identification
virtual int Type(void) const { return(defNeuronLSTM); }
};
//+------------------------------------------------------------------+
//| Class constructor                                                |
//+------------------------------------------------------------------+
CNeuronLSTM::CNeuronLSTM(void) : m_cForgetGate(new CNeuronBase()),
                                 m_cInputGate(new CNeuronBase()),
                                 m_cNewContent(new CNeuronBase()),
                                 m_cOutputGate(new CNeuronBase()),
                                 m_cMemorys(new CArrayObj()),
                                 m_cHiddenStates(new CArrayObj()),
                                 m_cInputs(new CArrayObj()),
                                 m_cForgetGateOuts(new CArrayObj()),
                                 m_cInputGateOuts(new CArrayObj()),
                                 m_cNewContentOuts(new CArrayObj()),
                                 m_cOutputGateOuts(new CArrayObj()),
                                 m_cInputGradient(new CBufferDouble()),
                                 m_iDepth(2)
  {
//--- Every owned object (four gate layers, seven history stacks and the
//--- input-gradient accumulator) is allocated in the initialization list,
//--- in declaration order. The destructor releases all of them.
  }
//+------------------------------------------------------------------+
//| Class destructor                                                 |
//+------------------------------------------------------------------+
CNeuronLSTM::~CNeuronLSTM(void)
  {
//--- Gate layers
   if(m_cForgetGate != NULL)
      delete m_cForgetGate;
   if(m_cInputGate != NULL)
      delete m_cInputGate;
   if(m_cNewContent != NULL)
      delete m_cNewContent;
   if(m_cOutputGate != NULL)
      delete m_cOutputGate;
//--- History stacks (each CArrayObj releases the objects it stores)
   if(m_cMemorys != NULL)
      delete m_cMemorys;
   if(m_cHiddenStates != NULL)
      delete m_cHiddenStates;
   if(m_cInputs != NULL)
      delete m_cInputs;
   if(m_cForgetGateOuts != NULL)
      delete m_cForgetGateOuts;
   if(m_cInputGateOuts != NULL)
      delete m_cInputGateOuts;
   if(m_cNewContentOuts != NULL)
      delete m_cNewContentOuts;
   if(m_cOutputGateOuts != NULL)
      delete m_cOutputGateOuts;
//--- Input-gradient accumulator
   if(m_cInputGradient != NULL)
      delete m_cInputGradient;
  }
//+------------------------------------------------------------------+
//| Class initialization method                                      |
//+------------------------------------------------------------------+
bool CNeuronLSTM::Init(CLayerDescription *desc)
  {
//--- Initializes the LSTM block from a layer description.
//---   desc.count      - number of LSTM units (layer outputs), must be > 0
//---   desc.window     - size of the previous layer's output, must be != 0
//---   desc.window_out - history depth; values below 2 are raised to 2
//--- Returns false on invalid input or on any allocation/initialization failure.
//--- The caller's description is left unchanged.
//--- Control block
   if(!desc || desc.type != Type() || desc.count <= 0 || desc.window == 0)
      return false;
//--- Description shared by the internal gate layers: every gate receives the
//--- concatenation [previous layer outputs | hidden state] (window + count values)
   CLayerDescription *temp = new CLayerDescription();
   if(!temp)
      return false;
   temp.type = defNeuronBase;
   temp.window = desc.window + desc.count;
   temp.count = desc.count;
   temp.activation = ACT_SIGMOID;
   temp.activation_params[0] = 1;
   temp.activation_params[1] = 0;
   temp.optimization = desc.optimization;
//--- Initialize the parent class without a weight matrix (window = 0); the
//--- trainable weights belong to the gates. Restore the caller's window right
//--- after the call so the description is not left modified.
   int window = desc.window;
   desc.window = 0;
   bool parent_ok = CNeuronBase::Init(desc);
   desc.window = window;
   if(!parent_ok)
     {
      delete temp;
      return false;
     }
   m_iDepth = (int)fmax(desc.window_out, 2);
//--- Create (if needed) and initialize the internal objects. Every failure
//--- breaks out of the do/while(false) to a single cleanup point, so temp is
//--- always released.
   bool ok = false;
   do
     {
      //--- ForgetGate
      if(CheckPointer(m_cForgetGate) == POINTER_INVALID)
        {
         m_cForgetGate = new CNeuronBase();
         if(CheckPointer(m_cForgetGate) == POINTER_INVALID)
            break;
        }
      if(!m_cForgetGate.Init(temp))
         break;
      //--- InputGate
      if(CheckPointer(m_cInputGate) == POINTER_INVALID)
        {
         m_cInputGate = new CNeuronBase();
         if(CheckPointer(m_cInputGate) == POINTER_INVALID)
            break;
        }
      if(!m_cInputGate.Init(temp))
         break;
      //--- OutputGate
      if(CheckPointer(m_cOutputGate) == POINTER_INVALID)
        {
         m_cOutputGate = new CNeuronBase();
         if(CheckPointer(m_cOutputGate) == POINTER_INVALID)
            break;
        }
      if(!m_cOutputGate.Init(temp))
         break;
      //--- NewContent (tanh activation instead of sigmoid)
      if(CheckPointer(m_cNewContent) == POINTER_INVALID)
        {
         m_cNewContent = new CNeuronBase();
         if(CheckPointer(m_cNewContent) == POINTER_INVALID)
            break;
        }
      temp.activation = ACT_TANH;
      if(!m_cNewContent.Init(temp))
         break;
      //--- InputGradient accumulator, sized like the concatenated input
      if(CheckPointer(m_cInputGradient) == POINTER_INVALID)
        {
         m_cInputGradient = new CBufferDouble();
         if(CheckPointer(m_cInputGradient) == POINTER_INVALID)
            break;
        }
      if(!m_cInputGradient.BufferInit(temp.window, 0))
         break;
      ok = true;
     }
   while(false);
   delete temp;
   if(!ok)
      return false;
//--- Reset the recurrent state. Drop all history (a re-initialized layer must not
//--- keep entries sized for its previous configuration). Then seed the memory and
//--- hidden-state stacks with one buffer from CreateBuffer; on an empty stack that
//--- buffer is zero-filled and shaped like the layer outputs.
   if(CheckPointer(m_cMemorys) != POINTER_INVALID)
      m_cMemorys.Clear();
   if(CheckPointer(m_cHiddenStates) != POINTER_INVALID)
      m_cHiddenStates.Clear();
   if(CheckPointer(m_cInputs) != POINTER_INVALID)
      m_cInputs.Clear();
   if(CheckPointer(m_cForgetGateOuts) != POINTER_INVALID)
      m_cForgetGateOuts.Clear();
   if(CheckPointer(m_cInputGateOuts) != POINTER_INVALID)
      m_cInputGateOuts.Clear();
   if(CheckPointer(m_cNewContentOuts) != POINTER_INVALID)
      m_cNewContentOuts.Clear();
   if(CheckPointer(m_cOutputGateOuts) != POINTER_INVALID)
      m_cOutputGateOuts.Clear();
//--- Memory
   CBufferDouble *buffer = CreateBuffer(m_cMemorys);
   if(CheckPointer(buffer) == POINTER_INVALID)
      return false;
   if(!m_cMemorys.Add(buffer))
     {
      delete buffer;
      return false;
     }
//--- HiddenStates
   buffer = CreateBuffer(m_cHiddenStates);
   if(CheckPointer(buffer) == POINTER_INVALID)
      return false;
   if(!m_cHiddenStates.Add(buffer))
     {
      delete buffer;
      return false;
     }
//--- Pass the current OpenCL context (possibly NULL) to the internal objects.
//--- The result is ignored on purpose: SetOpenCL reports false in CPU-only mode.
   SetOpenCL(m_cOpenCL);
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Passes the OpenCL object pointer down to all                     |
//| internal objects                                                 |
//+------------------------------------------------------------------+
bool CNeuronLSTM::SetOpenCL(CMyOpenCL *opencl)
  {
//--- Attaches the OpenCL context (or NULL for CPU mode) to this layer and to
//--- every internal object. Returns true only if the block ends up running
//--- on OpenCL.
//--- Previously, the missing '!' on all but the first call deleted the shared
//--- context whenever an internal layer attached successfully. Short-circuiting
//--- also left the remaining gates with a stale context.
//--- Call the parent class method; it stores the context in m_cOpenCL
   CNeuronBase::SetOpenCL(opencl);
//--- Hand the same context to every internal object. Each call is made
//--- unconditionally so no gate keeps a stale context.
   bool ok = true;
   if(CheckPointer(m_cForgetGate) == POINTER_INVALID || !m_cForgetGate.SetOpenCL(m_cOpenCL))
      ok = false;
   if(CheckPointer(m_cInputGate) == POINTER_INVALID || !m_cInputGate.SetOpenCL(m_cOpenCL))
      ok = false;
   if(CheckPointer(m_cOutputGate) == POINTER_INVALID || !m_cOutputGate.SetOpenCL(m_cOpenCL))
      ok = false;
   if(CheckPointer(m_cNewContent) == POINTER_INVALID || !m_cNewContent.SetOpenCL(m_cOpenCL))
      ok = false;
//--- The gradient accumulator needs a device buffer only in OpenCL mode
   if(!!m_cOpenCL &&
      (CheckPointer(m_cInputGradient) == POINTER_INVALID || !m_cInputGradient.BufferCreate(m_cOpenCL)))
      ok = false;
//--- If any internal object could not attach, fall back to CPU mode for the
//--- whole block. The context object is NOT deleted here: it was passed in by
//--- the caller (NOTE(review): ownership presumably stays with the network - confirm).
   if(!!m_cOpenCL && !ok)
     {
      CNeuronBase::SetOpenCL(NULL);
      m_cOpenCL = NULL;
      if(CheckPointer(m_cForgetGate) != POINTER_INVALID)
         m_cForgetGate.SetOpenCL(NULL);
      if(CheckPointer(m_cInputGate) != POINTER_INVALID)
         m_cInputGate.SetOpenCL(NULL);
      if(CheckPointer(m_cOutputGate) != POINTER_INVALID)
         m_cOutputGate.SetOpenCL(NULL);
      if(CheckPointer(m_cNewContent) != POINTER_INVALID)
         m_cNewContent.SetOpenCL(NULL);
     }
//---
   return(!!m_cOpenCL);
  }
//+------------------------------------------------------------------+
//| Removes excess entries from a history stack                      |
//+------------------------------------------------------------------+
void CNeuronLSTM::ClearBuffer(CArrayObj *buffer)
  {
//--- Keeps only the newest m_iDepth + 1 entries (indices 0..m_iDepth) of a
//--- history stack and deletes everything older. Invalid stacks are ignored.
   if(CheckPointer(buffer) != POINTER_INVALID)
     {
      int keep = m_iDepth + 1;
      int count = buffer.Total();
      if(count > keep)
         buffer.DeleteRange(keep, count);
     }
  }
//+------------------------------------------------------------------+
//| Feed-forward method                                              |
//+------------------------------------------------------------------+
bool CNeuronLSTM::FeedForward(CNeuronBase *prevLayer)
  {
//--- Performs one LSTM step.
//---   1. Builds the input [prevLayer outputs | hidden] and runs the four gates on it.
//---   2. Updates (element-wise) memory = fg * memory + ig * nc and
//---      hidden = og * tanh(memory), on the CPU or with an OpenCL kernel.
//---   3. Copies hidden into the layer outputs and pushes all per-step data onto
//---      the history stacks used by CalcHiddenGradient.
//--- Every failure releases the objects allocated so far (the OpenCL branch
//--- used to return without freeing inputs/memory/hidden).
//--- Check that all objects exist
   if(CheckPointer(prevLayer) == POINTER_INVALID ||
      CheckPointer(prevLayer.GetOutputs()) == POINTER_INVALID ||
      CheckPointer(m_cOutputs) == POINTER_INVALID ||
      CheckPointer(m_cForgetGate) == POINTER_INVALID ||
      CheckPointer(m_cInputGate) == POINTER_INVALID ||
      CheckPointer(m_cOutputGate) == POINTER_INVALID ||
      CheckPointer(m_cNewContent) == POINTER_INVALID)
      return false;
//--- Make sure the stack of concatenated inputs exists before anything else is allocated
   if(CheckPointer(m_cInputs) == POINTER_INVALID)
     {
      m_cInputs = new CArrayObj();
      if(CheckPointer(m_cInputs) == POINTER_INVALID)
         return false;
     }
//--- Working buffers for the new memory and hidden state. CreateBuffer supplies
//--- their starting values, which the update below treats as the previous
//--- memory and hidden state.
   CBufferDouble *memory = CreateBuffer(m_cMemorys);
   if(CheckPointer(memory) == POINTER_INVALID)
      return false;
   CBufferDouble *hidden = CreateBuffer(m_cHiddenStates);
   if(CheckPointer(hidden) == POINTER_INVALID)
     {
      delete memory;
      return false;
     }
//--- For gradient checking only:
//memory.m_mMatrix.Fill(0);
//hidden.m_mMatrix.Fill(0);
//--- Compute the step. Any failure breaks out of the do/while(false) to one cleanup point.
   CNeuronBase *inputs = NULL;
   bool ok = false;
   do
     {
      //--- Buffer of concatenated inputs [prevLayer outputs | hidden]
      inputs = new CNeuronBase();
      if(CheckPointer(inputs) == POINTER_INVALID)
         break;
      CLayerDescription *desc = new CLayerDescription();
      if(CheckPointer(desc) == POINTER_INVALID)
         break;
      desc.type = defNeuronBase;
      desc.count = (int)(prevLayer.GetOutputs().Total() + m_cOutputs.Total());
      desc.window = 0;
      bool init_ok = inputs.Init(desc);
      delete desc;
      if(!init_ok)
         break;
      CBufferDouble *inputs_buffer = inputs.GetOutputs();
      if(CheckPointer(inputs_buffer) == POINTER_INVALID)
         break;
      inputs_buffer.m_mMatrix = prevLayer.GetOutputs().m_mMatrix;
      ulong shift = inputs_buffer.m_mMatrix.Cols();
      ulong cols = hidden.m_mMatrix.Cols();
      if(!inputs_buffer.m_mMatrix.Resize(inputs_buffer.m_mMatrix.Rows(), shift + cols))
         break;
      bool copied = true;
      for(ulong c = 0; c < cols && copied; c++)
         copied = inputs_buffer.m_mMatrix.Col(hidden.m_mMatrix.Col(c), shift + c);
      if(!copied)
         break;
      //--- Forward pass of the internal layers
      if(!m_cForgetGate.FeedForward(inputs) ||
         !m_cInputGate.FeedForward(inputs) ||
         !m_cOutputGate.FeedForward(inputs) ||
         !m_cNewContent.FeedForward(inputs))
         break;
      CBufferDouble *fg = m_cForgetGate.GetOutputs();
      CBufferDouble *ig = m_cInputGate.GetOutputs();
      CBufferDouble *og = m_cOutputGate.GetOutputs();
      CBufferDouble *nc = m_cNewContent.GetOutputs();
      if(CheckPointer(fg) == POINTER_INVALID || CheckPointer(ig) == POINTER_INVALID ||
         CheckPointer(og) == POINTER_INVALID || CheckPointer(nc) == POINTER_INVALID)
         break;
      //--- Branch by computing device
      if(!m_cOpenCL)
        {
         //--- memory = fg * memory + ig * nc;  hidden = og * tanh(memory)  (element-wise)
         memory.m_mMatrix *= fg.m_mMatrix;
         memory.m_mMatrix += ig.m_mMatrix * nc.m_mMatrix;
         ulong total = memory.Total();
         bool flat_ok = true;
         for(ulong i = 0; i < total && flat_ok; i++)
            flat_ok = hidden.m_mMatrix.Flat(i, MathTanh(memory.m_mMatrix.Flat(i)) * og.m_mMatrix.Flat(i));
         if(!flat_ok)
            break;
        }
      else
        {
         //--- Create the device buffers
         if(!fg.BufferCreate(m_cOpenCL) ||
            !ig.BufferCreate(m_cOpenCL) ||
            !og.BufferCreate(m_cOpenCL) ||
            !nc.BufferCreate(m_cOpenCL) ||
            !memory.BufferCreate(m_cOpenCL) ||
            !hidden.BufferCreate(m_cOpenCL))
            break;
         //--- Pass the kernel arguments
         if(!m_cOpenCL.SetArgumentBuffer(def_k_LSTMFeedForward, def_lstmff_forgetgate, fg.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMFeedForward, def_lstmff_inputgate, ig.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMFeedForward, def_lstmff_newcontent, nc.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMFeedForward, def_lstmff_outputgate, og.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMFeedForward, def_lstmff_memory, memory.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMFeedForward, def_lstmff_hiddenstate, hidden.GetIndex()) ||
            !m_cOpenCL.SetArgument(def_k_LSTMFeedForward, def_lstmff_outputs_total, m_cOutputs.Total()))
            break;
         //--- Launch the kernel (one work item per group of 4 outputs)
         int NDRange[] = {(int)(m_cOutputs.Total() + 3) / 4};
         int off_set[] = {0};
         if(!m_cOpenCL.Execute(def_k_LSTMFeedForward, 1, off_set, NDRange))
            break;
         //--- Read back the results
         if(!memory.BufferRead() || !hidden.BufferRead())
            break;
        }
      //--- Copy the hidden state into the layer's output buffer
      m_cOutputs.m_mMatrix = hidden.m_mMatrix;
      if(!!m_cOpenCL && !m_cOutputs.BufferWrite())
         break;
      ok = true;
     }
   while(false);
   if(!ok)
     {
      if(CheckPointer(inputs) != POINTER_INVALID)
         delete inputs;
      delete memory;
      delete hidden;
      return false;
     }
//--- Save the current step on the history stacks
   if(!m_cInputs.Insert(inputs, 0))
     {
      delete inputs;
      delete memory;
      delete hidden;
      return false;
     }
   ClearBuffer(m_cInputs);
//--- Gate outputs are stored as copies (create_new = true)
   if(!InsertBuffer(m_cForgetGateOuts, m_cForgetGate.GetOutputs()) ||
      !InsertBuffer(m_cInputGateOuts, m_cInputGate.GetOutputs()) ||
      !InsertBuffer(m_cOutputGateOuts, m_cOutputGate.GetOutputs()) ||
      !InsertBuffer(m_cNewContentOuts, m_cNewContent.GetOutputs()))
     {
      delete memory;
      delete hidden;
      return false;
     }
//--- memory and hidden are handed over to the stacks (create_new = false);
//--- InsertBuffer deletes the element itself when the insertion fails
   if(!InsertBuffer(m_cMemorys, memory, false))
     {
      delete hidden;
      return false;
     }
   if(!InsertBuffer(m_cHiddenStates, hidden, false))
      return false;
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Adds data to a history stack                                     |
//+------------------------------------------------------------------+
bool CNeuronLSTM::InsertBuffer(CArrayObj *&array, CBufferDouble *element, bool create_new = true)
  {
//--- Pushes a buffer to the front (index 0) of a history stack and trims the stack.
//---   array      - target stack; created if it does not exist yet
//---   element    - buffer to store
//---   create_new - true: store a copy of element (the caller keeps element);
//---                false: store element itself (ownership moves to this method)
//--- Returns false on failure. When create_new == false, element is deleted on
//--- EVERY failure path after validation, so callers must not delete it again
//--- (previously it leaked if the stack could not be created).
//--- Control block
   if(CheckPointer(element) == POINTER_INVALID)
      return false;
   if(CheckPointer(array) == POINTER_INVALID)
     {
      array = new CArrayObj();
      if(CheckPointer(array) == POINTER_INVALID)
        {
         if(!create_new)
            delete element;
         return false;
        }
     }
//--- Decide what is stored: a private copy or the transferred element
   CBufferDouble *item = element;
   if(create_new)
     {
      item = new CBufferDouble();
      if(CheckPointer(item) == POINTER_INVALID)
         return false;
      item.m_mMatrix = element.m_mMatrix;
     }
   if(!array.Insert(item, 0))
     {
      //--- item is either our copy or the element whose ownership we received
      delete item;
      return false;
     }
//--- Remove history beyond the configured depth
   ClearBuffer(array);
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Creates a new data buffer                                        |
//+------------------------------------------------------------------+
CBufferDouble *CNeuronLSTM::CreateBuffer(CArrayObj *&array)
  {
//--- Allocates a new state buffer for a history stack (creating the stack if missing).
//--- If the stack already holds entries, the buffer starts as a copy of the
//--- newest one (index 0), so FeedForward continues from the previous memory
//--- and hidden state. Otherwise it is zero-filled with the shape of the layer
//--- outputs. In OpenCL mode a device buffer is created as well.
//--- Returns NULL on failure; the buffer is NOT inserted into the stack. Before
//--- this fix a deleted pointer was returned when BufferCreate failed, and the
//--- previous state was never carried over.
   if(CheckPointer(array) == POINTER_INVALID)
     {
      array = new CArrayObj();
      if(CheckPointer(array) == POINTER_INVALID)
         return NULL;
     }
   CBufferDouble *buffer = new CBufferDouble();
   if(CheckPointer(buffer) == POINTER_INVALID)
      return NULL;
   if(array.Total() > 0)
     {
      //--- Continue from the most recent state
      CBufferDouble *last = array.At(0);
      if(CheckPointer(last) == POINTER_INVALID)
        {
         delete buffer;
         return NULL;
        }
      buffer.m_mMatrix = last.m_mMatrix;
     }
   else
     {
      //--- Empty stack: zero initial state shaped like the layer outputs
      if(CheckPointer(m_cOutputs) == POINTER_INVALID ||
         !buffer.BufferInit(m_cOutputs.Rows(), m_cOutputs.Cols(), 0))
        {
         delete buffer;
         return NULL;
        }
     }
   if(!!m_cOpenCL && !buffer.BufferCreate(m_cOpenCL))
     {
      delete buffer;
      return NULL;
     }
//---
   return buffer;
  }
//+------------------------------------------------------------------+
//| Propagates the error gradient through the hidden layer           |
//+------------------------------------------------------------------+
bool CNeuronLSTM::CalcHiddenGradient(CNeuronBase *prevLayer)
  {
//--- Backpropagates the output error gradient through the stored time steps,
//--- newest first (index 0). For each step i:
//---   - the gate gradients are computed from the stored gate outputs, memory and hidden state;
//---   - they are passed through the gate layers to the concatenated input, and
//---     each gate's weight deltas are accumulated;
//---   - the input gradient is split into the previous layer's part (written to
//---     prevLayer for i == 0 only) and the hidden-state part, which becomes the
//---     output gradient for step i + 1.
//--- Returns false if any object is missing or any step fails.
//--- Check that all objects exist
   if(CheckPointer(prevLayer) == POINTER_INVALID ||
      CheckPointer(prevLayer.GetGradients()) == POINTER_INVALID ||
      CheckPointer(m_cGradients) == POINTER_INVALID ||
      CheckPointer(m_cOutputs) == POINTER_INVALID ||
      CheckPointer(m_cMemorys) == POINTER_INVALID ||
      CheckPointer(m_cHiddenStates) == POINTER_INVALID ||
      CheckPointer(m_cInputs) == POINTER_INVALID ||
      CheckPointer(m_cForgetGate) == POINTER_INVALID ||
      CheckPointer(m_cForgetGateOuts) == POINTER_INVALID ||
      CheckPointer(m_cInputGate) == POINTER_INVALID ||
      CheckPointer(m_cInputGateOuts) == POINTER_INVALID ||
      CheckPointer(m_cOutputGate) == POINTER_INVALID ||
      CheckPointer(m_cOutputGateOuts) == POINTER_INVALID ||
      CheckPointer(m_cNewContent) == POINTER_INVALID ||
      CheckPointer(m_cNewContentOuts) == POINTER_INVALID)
      return false;
//--- Check that forward-pass data is available (one extra memory entry is needed per step)
   int total = (int)fmin(m_cMemorys.Total(), m_cHiddenStates.Total()) - 1;
   if(total <= 0)
      return false;
//--- Gradient and output buffers of the internal layers
   CBufferDouble *fg_grad = m_cForgetGate.GetGradients();
   if(CheckPointer(fg_grad) == POINTER_INVALID)
      return false;
   CBufferDouble *fg_out = m_cForgetGate.GetOutputs();
   if(CheckPointer(fg_out) == POINTER_INVALID)
      return false;
   CBufferDouble *ig_grad = m_cInputGate.GetGradients();
   if(CheckPointer(ig_grad) == POINTER_INVALID)
      return false;
   CBufferDouble *ig_out = m_cInputGate.GetOutputs();
   if(CheckPointer(ig_out) == POINTER_INVALID)
      return false;
   CBufferDouble *og_grad = m_cOutputGate.GetGradients();
   if(CheckPointer(og_grad) == POINTER_INVALID)
      return false;
   CBufferDouble *og_out = m_cOutputGate.GetOutputs();
   if(CheckPointer(og_out) == POINTER_INVALID)
      return false;
   CBufferDouble *nc_grad = m_cNewContent.GetGradients();
   if(CheckPointer(nc_grad) == POINTER_INVALID)
      return false;
   CBufferDouble *nc_out = m_cNewContent.GetOutputs();
   if(CheckPointer(nc_out) == POINTER_INVALID)
      return false;
//--- Loop over the accumulated history
   for(int i = 0; i < total; i++)
     {
      //--- Buffers of step i from the stacks; memory of step i + 1 is the state before step i
      CBufferDouble *fg = m_cForgetGateOuts.At(i);
      if(CheckPointer(fg) == POINTER_INVALID)
         return false;
      CBufferDouble *ig = m_cInputGateOuts.At(i);
      if(CheckPointer(ig) == POINTER_INVALID)
         return false;
      CBufferDouble *og = m_cOutputGateOuts.At(i);
      if(CheckPointer(og) == POINTER_INVALID)
         return false;
      CBufferDouble *nc = m_cNewContentOuts.At(i);
      if(CheckPointer(nc) == POINTER_INVALID)
         return false;
      CBufferDouble *memory = m_cMemorys.At(i + 1);
      if(CheckPointer(memory) == POINTER_INVALID)
         return false;
      CBufferDouble *hidden = m_cHiddenStates.At(i);
      if(CheckPointer(hidden) == POINTER_INVALID)
         return false;
      CNeuronBase *inputs = m_cInputs.At(i);
      if(CheckPointer(inputs) == POINTER_INVALID)
         return false;
      //--- Branch by computing device
      if(CheckPointer(m_cOpenCL) == POINTER_INVALID)
        {
         //--- Recover tanh(memory) from hidden = og * tanh(memory)
         MATRIX m = hidden.m_mMatrix / (og.m_mMatrix + 1e-8);
         //--- OutputGate gradient
         MATRIX grad = m_cGradients.m_mMatrix;
         og_grad.m_mMatrix = grad * m;
         //--- Memory gradient
         grad *= og.m_mMatrix;
         //--- Apply the tanh derivative
         grad *= 1 - m.Power(2);
         //--- InputGate gradient
         ig_grad.m_mMatrix = grad * nc.m_mMatrix;
         //--- NewContent gradient
         nc_grad.m_mMatrix = grad * ig.m_mMatrix;
         //--- ForgetGate gradient
         fg_grad.m_mMatrix = grad * memory.m_mMatrix;
        }
      else
        {
         //--- Create any missing device buffers. The copies stored in the history
         //--- stacks may exist only on the CPU side; failing on them made the
         //--- OpenCL path unusable.
         if((hidden.GetIndex() < 0 && !hidden.BufferCreate(m_cOpenCL)) ||
            (m_cGradients.GetIndex() < 0 && !m_cGradients.BufferCreate(m_cOpenCL)) ||
            (ig.GetIndex() < 0 && !ig.BufferCreate(m_cOpenCL)) ||
            (og.GetIndex() < 0 && !og.BufferCreate(m_cOpenCL)) ||
            (nc.GetIndex() < 0 && !nc.BufferCreate(m_cOpenCL)) ||
            (memory.GetIndex() < 0 && !memory.BufferCreate(m_cOpenCL)) ||
            (fg_grad.GetIndex() < 0 && !fg_grad.BufferCreate(m_cOpenCL)) ||
            (ig_grad.GetIndex() < 0 && !ig_grad.BufferCreate(m_cOpenCL)) ||
            (og_grad.GetIndex() < 0 && !og_grad.BufferCreate(m_cOpenCL)) ||
            (nc_grad.GetIndex() < 0 && !nc_grad.BufferCreate(m_cOpenCL)))
            return false;
         //--- Pass the kernel arguments
         if(!m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_fg_gradients, fg_grad.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_gradients, m_cGradients.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_ig_gradients, ig_grad.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_inputgate, ig.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_memory, memory.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_nc_gradients, nc_grad.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_newcontent, nc.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_og_gradients, og_grad.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_outputgate, og.GetIndex()) ||
            !m_cOpenCL.SetArgumentBuffer(def_k_LSTMHiddenGradients, def_lstmhgr_outputs, hidden.GetIndex()) ||
            !m_cOpenCL.SetArgument(def_k_LSTMHiddenGradients, def_lstmhgr_outputs_total, m_cOutputs.Total()))
            return false;
         //--- Launch the kernel (one work item per group of 4 outputs)
         int NDRange[] = { (int)(m_cOutputs.Total() + 3) / 4 };
         int off_set[] = {0};
         if(!m_cOpenCL.Execute(def_k_LSTMHiddenGradients, 1, off_set, NDRange))
            return false;
         //--- Read back the results
         if(!fg_grad.BufferRead() || !ig_grad.BufferRead() ||
            !og_grad.BufferRead() || !nc_grad.BufferRead())
            return false;
        }
      //--- Restore the historical gate outputs of step i into the internal layers
      fg_out.m_mMatrix = fg.m_mMatrix;
      ig_out.m_mMatrix = ig.m_mMatrix;
      og_out.m_mMatrix = og.m_mMatrix;
      nc_out.m_mMatrix = nc.m_mMatrix;
      //--- In OpenCL mode upload them too, so the device copies match step i
      //--- (NOTE(review): assumes CNeuronBase::CalcHiddenGradient reads outputs from the device - confirm)
      if(CheckPointer(m_cOpenCL) != POINTER_INVALID &&
         (!fg_out.BufferWrite() || !ig_out.BufferWrite() ||
          !og_out.BufferWrite() || !nc_out.BufferWrite()))
         return false;
      //--- Propagate the gradient through the internal layers and sum the four
      //--- contributions to the input gradient in m_cInputGradient
      if(!m_cForgetGate.CalcHiddenGradient(inputs))
         return false;
      if(!m_cInputGradient)
        {
         m_cInputGradient = new CBufferDouble();
         if(CheckPointer(m_cInputGradient) == POINTER_INVALID)
            return false;
        }
      m_cInputGradient.m_mMatrix = inputs.GetGradients().m_mMatrix;
      if(!m_cInputGate.CalcHiddenGradient(inputs))
         return false;
      if(!m_cInputGradient.SumArray(inputs.GetGradients()))
         return false;
      if(!m_cOutputGate.CalcHiddenGradient(inputs))
         return false;
      if(!m_cInputGradient.SumArray(inputs.GetGradients()))
         return false;
      if(!m_cNewContent.CalcHiddenGradient(inputs))
         return false;
      if(!inputs.GetGradients().SumArray(m_cInputGradient))
         return false;
      //--- Accumulate the weight deltas of the internal layers
      if(!m_cForgetGate.CalcDeltaWeights(inputs))
         return false;
      if(!m_cInputGate.CalcDeltaWeights(inputs))
         return false;
      if(!m_cOutputGate.CalcDeltaWeights(inputs))
         return false;
      if(!m_cNewContent.CalcDeltaWeights(inputs))
         return false;
      //--- Split the input gradient into [prevLayer part | hidden-state part]
      ulong split[] = {prevLayer.GetGradients().Cols()};
      MATRIX m[];
      if(!inputs.GetGradients().m_mMatrix.Vsplit(split, m))
         return false;
      //--- Only the current step's gradient goes to the previous layer
      if(i == 0)
        {
         CBufferDouble *prevLayer_grad = prevLayer.GetGradients();
         prevLayer_grad.m_mMatrix = m[0];
         if(m_cOpenCL && !prevLayer_grad.BufferWrite())
            return false;
        }
      //--- The hidden-state gradient becomes the output gradient of the next iteration
      m_cGradients.m_mMatrix = m[1];
      if(m_cOpenCL && !m_cGradients.BufferWrite())
         return false;
     }
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Updates the weight matrices                                      |
//+------------------------------------------------------------------+
bool CNeuronLSTM::UpdateWeights(int batch_size, double learningRate, double &Beta[], double &Lambda[])
  {
//--- Updates the weights of the four internal gates.
//--- CalcHiddenGradient accumulates weight deltas once per stored time step,
//--- so the batch passed to the gates is batch_size * m_iDepth.
//--- Returns false if a gate is missing, the depth is not positive, or any update fails.
   bool ready = m_iDepth > 0 &&
                CheckPointer(m_cForgetGate) != POINTER_INVALID &&
                CheckPointer(m_cInputGate) != POINTER_INVALID &&
                CheckPointer(m_cOutputGate) != POINTER_INVALID &&
                CheckPointer(m_cNewContent) != POINTER_INVALID;
   if(!ready)
      return false;
//--- Effective batch over all time steps
   int steps_batch = batch_size * m_iDepth;
//--- Gates are updated in a fixed order; the first failure stops the chain
   return(m_cForgetGate.UpdateWeights(steps_batch, learningRate, Beta, Lambda) &&
          m_cInputGate.UpdateWeights(steps_batch, learningRate, Beta, Lambda) &&
          m_cOutputGate.UpdateWeights(steps_batch, learningRate, Beta, Lambda) &&
          m_cNewContent.UpdateWeights(steps_batch, learningRate, Beta, Lambda));
  }
//+------------------------------------------------------------------+
//| Saves class elements to a file                                   |
//+------------------------------------------------------------------+
bool CNeuronLSTM::Save(const int file_handle)
  {
//--- Writes the parent layer data, the history depth and the four gate layers
//--- (ForgetGate, InputGate, OutputGate, NewContent - the order Load reads back).
//--- Returns false as soon as any write fails.
   bool parent_saved = CNeuronBase::Save(file_handle);
   if(!parent_saved)
      return false;
//--- Depth of the stored history
   uint written = FileWriteInteger(file_handle, m_iDepth);
   if(written <= 0)
      return false;
//--- Internal gate layers
   return(m_cForgetGate.Save(file_handle) &&
          m_cInputGate.Save(file_handle) &&
          m_cOutputGate.Save(file_handle) &&
          m_cNewContent.Save(file_handle));
  }
//+------------------------------------------------------------------+
//| Restores the class from a file                                   |
//+------------------------------------------------------------------+
bool CNeuronLSTM::Load(const int file_handle)
  {
//--- Restores the layer written by Save: parent data, history depth and the
//--- four gate layers (each preceded by its type id). Afterwards the recurrent
//--- state is reset: all history is dropped, and the memory and hidden-state
//--- stacks are seeded with one buffer from CreateBuffer.
//--- Returns false on any read/allocation failure. Buffers allocated here are
//--- released on failure; previously they leaked.
//--- Call the parent class method
   if(!CNeuronBase::Load(file_handle))
      return false;
//--- Read the constants
   m_iDepth = FileReadInteger(file_handle);
//--- Load the internal layers, creating any that are missing
   if(CheckPointer(m_cForgetGate) == POINTER_INVALID)
      m_cForgetGate = new CNeuronBase();
   if(CheckPointer(m_cForgetGate) == POINTER_INVALID ||
      FileReadInteger(file_handle) != defNeuronBase || !m_cForgetGate.Load(file_handle))
      return false;
   if(CheckPointer(m_cInputGate) == POINTER_INVALID)
      m_cInputGate = new CNeuronBase();
   if(CheckPointer(m_cInputGate) == POINTER_INVALID ||
      FileReadInteger(file_handle) != defNeuronBase || !m_cInputGate.Load(file_handle))
      return false;
   if(CheckPointer(m_cOutputGate) == POINTER_INVALID)
      m_cOutputGate = new CNeuronBase();
   if(CheckPointer(m_cOutputGate) == POINTER_INVALID ||
      FileReadInteger(file_handle) != defNeuronBase || !m_cOutputGate.Load(file_handle))
      return false;
   if(CheckPointer(m_cNewContent) == POINTER_INVALID)
      m_cNewContent = new CNeuronBase();
   if(CheckPointer(m_cNewContent) == POINTER_INVALID ||
      FileReadInteger(file_handle) != defNeuronBase || !m_cNewContent.Load(file_handle))
      return false;
//--- InputGradient accumulator
   if(CheckPointer(m_cInputGradient) == POINTER_INVALID)
     {
      m_cInputGradient = new CBufferDouble();
      if(CheckPointer(m_cInputGradient) == POINTER_INVALID)
         return false;
     }
//--- Drop all history first. The memory and hidden seeds then come from
//--- CreateBuffer on an empty stack (zero-filled, shaped like the layer
//--- outputs), the same as a first-time initialization.
   if(CheckPointer(m_cMemorys) != POINTER_INVALID)
      m_cMemorys.Clear();
   if(CheckPointer(m_cHiddenStates) != POINTER_INVALID)
      m_cHiddenStates.Clear();
   if(CheckPointer(m_cInputs) != POINTER_INVALID)
      m_cInputs.Clear();
   if(CheckPointer(m_cForgetGateOuts) != POINTER_INVALID)
      m_cForgetGateOuts.Clear();
   if(CheckPointer(m_cInputGateOuts) != POINTER_INVALID)
      m_cInputGateOuts.Clear();
   if(CheckPointer(m_cNewContentOuts) != POINTER_INVALID)
      m_cNewContentOuts.Clear();
   if(CheckPointer(m_cOutputGateOuts) != POINTER_INVALID)
      m_cOutputGateOuts.Clear();
//--- Memory
   CBufferDouble *buffer = CreateBuffer(m_cMemorys);
   if(CheckPointer(buffer) == POINTER_INVALID)
      return false;
   if(!m_cMemorys.Add(buffer))
     {
      delete buffer;
      return false;
     }
//--- HiddenStates
   buffer = CreateBuffer(m_cHiddenStates);
   if(CheckPointer(buffer) == POINTER_INVALID)
      return false;
   if(!m_cHiddenStates.Add(buffer))
     {
      delete buffer;
      return false;
     }
//---
   return true;
  }
//+------------------------------------------------------------------+