//+------------------------------------------------------------------+
//|                                                       Warrior_EA |
//|                                                     AnimateDread |
//|                                                                  |
//+------------------------------------------------------------------+
#include "..\Expert\ExpertSignalAIBase.mqh"
// wizard description start
//+------------------------------------------------------------------+
//| Description of the class                                         |
//| Title=Signals of indicator 'LSTM AI'                             |
//| Type=SignalAdvanced                                              |
//| Name=LSTM AI                                                     |
//| ShortName=LSTM                                                   |
//| Class=CSignalLSTM                                                |
//| Page=signal_lstm                                                 |
//+------------------------------------------------------------------+
// wizard description end
//+------------------------------------------------------------------+
//| Class CSignalLSTM.                                               |
//| Purpose: Class of generator of trade signals based on            |
//|          the 'LSTM AI' indicator.                                |
//| Is derived from the CExpertSignalAIBase class.                   |
//+------------------------------------------------------------------+
class CSignalLSTM : public CExpertSignalAIBase
  {
protected:
   //--- "weights" of market models (0-100)
   int               m_pattern_0;

   //--- helpers used during initialization / tick handling
   bool              DeleteSavedState(void);
   bool              BuildLSTMTopology(void);
   long              StudyStartTimestamp(const string symbol);

public:
                     CSignalLSTM(void);
                    ~CSignalLSTM(void);
   //--- methods of adjusting "weights" of market models
   void              Pattern_0(int value) { m_pattern_0 = value; }
   void              ApplyPatternWeight(int patternNumber, int weight);
   //--- method of creating the indicator and timeseries
   virtual bool      InitIndicators(CIndicators *indicators);
   //--- methods of checking if the market models are formed
   virtual int       LongCondition(void);
   virtual int       ShortCondition(void);
   //--- event handler
   virtual void      OnTickHandler(void);
   virtual void      OnChartEventHandler(const int id, const long &lparam, const double &dparam, const string &sparam);
  };
//+------------------------------------------------------------------+
//| Constructor                                                      |
//| Sets the default pattern weight and derives the base file name   |
//| (<eaName>\Neural Networks\State\LSTM\<symbol>_<period>) used to  |
//| persist the network; InitIndicators appends a topology suffix.   |
//+------------------------------------------------------------------+
CSignalLSTM::CSignalLSTM(void) : m_pattern_0(10)
  {
   ID = "Recurrent";
   m_id = "LSTM";
   m_folderPath = eaName + "\\" + "Neural Networks" + "\\" + "State" + "\\" + m_id + "\\";
   m_fileName = m_folderPath + _Symbol + "_" + IntegerToString(_Period);
   m_pattern_count = 1;
  }
//+------------------------------------------------------------------+
//| Destructor                                                       |
//| Frees the network and the temporary buffer, removes the (unused) |
//| play/pause button if present and clears chart objects.           |
//+------------------------------------------------------------------+
CSignalLSTM::~CSignalLSTM(void)
  {
   if(CheckPointer(Net) != POINTER_INVALID)
      delete Net;
   if(CheckPointer(TempData) != POINTER_INVALID)
      delete TempData;
   if(ObjectFind(0, "PlayPauseButton") >= 0)
     {
      ObjectDelete(0, "PlayPauseButton");
     }
   PurgeChart();
  }
//+------------------------------------------------------------------+
//| Create indicators.                                               |
//| Loads (or, on failure, builds) the LSTM network, allocates the   |
//| temporary buffer and posts the first "study" custom event.       |
//| In training mode any saved state (.nnw/.cfg) is deleted first so |
//| that a fresh network is built.                                   |
//| Returns false if any allocation or the state cleanup fails.      |
//+------------------------------------------------------------------+
bool CSignalLSTM::InitIndicators(CIndicators *indicators)
  {
   if(m_isInitialized)
      return true; // Already initialized
   if(indicators == NULL)
      return false;
   if(!CExpertSignalCustom::InitIndicators(indicators))
      return false;
   if(!CExpertSignalAIBase::InitIndicators(indicators))
      return false;
   /*
   // Create a button on the top left corner of the chart
   // Define constants for colors
   color clrButtonBackground = clrSilver; // Background color of the button
   color clrButtonText = clrBlack; // Text color of the button
   color clrButtonBorder = clrBlack; // Border color of the button
   long x = ChartGetInteger(0, CHART_WIDTH_IN_PIXELS) - 0; // Adjust x position as needed
   int y = 290; // Adjust y position as needed
   m_buttonHandle = ObjectCreate(0, "PlayPauseButton", OBJ_BUTTON, 0, 0, 0);
   ObjectSetInteger(0, "PlayPauseButton", OBJPROP_XDISTANCE, x);
   ObjectSetInteger(0, "PlayPauseButton", OBJPROP_YDISTANCE, y);
   ObjectSetInteger(0, "PlayPauseButton", OBJPROP_CORNER, 2);
   ObjectSetString(0, "PlayPauseButton", OBJPROP_TEXT, "Pause");
   ObjectSetInteger(0, "PlayPauseButton", OBJPROP_FONTSIZE, 12);
   ObjectSetInteger(0, "PlayPauseButton", OBJPROP_COLOR, clrButtonBackground);
   ObjectSetInteger(0, "PlayPauseButton", OBJPROP_BORDER_COLOR, clrButtonBorder);
   ObjectSetInteger(0, "PlayPauseButton", OBJPROP_BORDER_TYPE, BORDER_FLAT);
   */
//--- release objects left over from a previous, failed attempt before re-creating them
   if(CheckPointer(Net) != POINTER_INVALID)
      delete Net;
   Net = new CNet(NULL);
   if(CheckPointer(Net) == POINTER_INVALID)
      return false;
//--- append the topology suffix only once, so a retry after failure does not grow the name
   string suffix = "_" + DoubleToString(MathRound(m_outputNeuronsCount)) + "_" + DoubleToString(MathRound(m_optimizationAlgo));
   int suffixPos = StringLen(m_fileName) - StringLen(suffix);
   if(suffixPos < 0 || StringSubstr(m_fileName, suffixPos) != suffix)
      m_fileName += suffix;
//--- training mode starts from scratch: drop saved state, then fall through and build a new network
   if(trainingMode && !DeleteSavedState())
      return false;
   if(!LoadAndCompareTopologyConfiguration(m_fileName, m_initialNeuronsCount, m_hiddenLayersCount, m_neuronsReduction, m_minNeuronsCount, m_optimizationAlgo, m_historyBars, m_outputNeuronsCount, m_neuronsCount, m_studyPeriod, m_minTrainYear, m_isInitialized, m_stopTrainWR, m_fractalPeriods))
     {
      SaveTopologyConfiguration(m_fileName, m_initialNeuronsCount, m_hiddenLayersCount, m_neuronsReduction, m_minNeuronsCount, m_optimizationAlgo, m_historyBars, m_outputNeuronsCount, m_neuronsCount, m_studyPeriod, m_minTrainYear, m_isInitialized, m_stopTrainWR, m_fractalPeriods);
     }
   if(!Net.Load(m_fileName + ".nnw", dError, dUndefine, dForecast, dtStudied, true))
     {
      int error_code = GetLastError();
      if(error_code != 5004) // 5004 = file could not be opened (e.g. no saved network yet): not worth logging
        {
         printf("%s -> Error loading previous network %s error code : %d", __FUNCTION__, m_fileName, error_code);
         ResetLastError();
        }
      if(!BuildLSTMTopology())
         return false;
     }
   if(CheckPointer(TempData) != POINTER_INVALID)
      delete TempData;
   TempData = new CArrayDouble();
   if(CheckPointer(TempData) == POINTER_INVALID)
      return false;
//--- custom event id 1 arrives in OnChartEvent as CHARTEVENT_CUSTOM + 1
   bEventStudy = EventChartCustom(ChartID(), 1, StudyStartTimestamp(_Symbol), 0, "Init");
   m_isInitialized = true;
   return true;
  }
//+------------------------------------------------------------------+
//| Delete the saved network (.nnw) and configuration (.cfg) files   |
//| from the common folder, retrying up to 5 times one second apart. |
//| A file that does not exist counts as already deleted, and each   |
//| file is retried independently, so a partially completed delete   |
//| can still finish on a later attempt.                             |
//| Returns true when neither file remains.                          |
//+------------------------------------------------------------------+
bool CSignalLSTM::DeleteSavedState(void)
  {
   string files[2];
   files[0] = m_fileName + ".nnw";
   files[1] = m_fileName + ".cfg";
   for(int attempt = 0; attempt < 5; attempt++)
     {
      bool remaining = false;
      for(int f = 0; f < 2; f++)
        {
         if(!FileIsExist(files[f], FILE_COMMON))
            continue;
         if(!FileDelete(files[f], FILE_COMMON))
           {
            remaining = true;
            Print(__FUNCTION__ + ": Retry " + IntegerToString(attempt + 1) + " failed to delete file: " + files[f] + " error code : " + IntegerToString(GetLastError()));
           }
        }
      if(!remaining)
        {
         // NOTE(review): FileIsExist may leave an error code for absent files; clear it so it does not leak into later checks
         ResetLastError();
         return true;
        }
      Sleep(1000);
     }
   Print(__FUNCTION__ + ": Failed to delete file after retries: " + m_fileName);
   return false;
  }
//+------------------------------------------------------------------+
//| Build the network topology and replace Net with a new CNet:      |
//|   input (m_historyBars * m_neuronsCount, no activation)          |
//|   -> LSTM -> m_hiddenLayersCount fully connected TANH layers     |
//|   (each m_neuronsReduction % of the previous, floored at         |
//|   m_minNeuronsCount) -> output (m_outputNeuronsCount, TANH).     |
//| Every layer uses SGD when m_optimizationAlgo == 0, else ADAM.    |
//| Returns false on allocation failure; nothing is leaked.          |
//+------------------------------------------------------------------+
bool CSignalLSTM::BuildLSTMTopology(void)
  {
   CArrayObj *Topology = new CArrayObj();
   if(CheckPointer(Topology) == POINTER_INVALID)
      return false;
//--- Input Layer
   CLayerDescription *desc = new CLayerDescription();
   if(CheckPointer(desc) == POINTER_INVALID)
     {
      delete Topology;
      return false;
     }
   desc.count = m_historyBars * m_neuronsCount;
   desc.type = defNeuron;
   desc.activation = NONE;
   desc.optimization = (m_optimizationAlgo == 0) ? SGD : ADAM;
   if(!Topology.Add(desc))
     {
      delete desc; // not owned by Topology because Add failed
      delete Topology;
      return false;
     }
//--- LSTM Layer
   desc = new CLayerDescription();
   if(CheckPointer(desc) == POINTER_INVALID)
     {
      delete Topology;
      return false;
     }
   // NOTE(review): the LSTM width is taken from m_hiddenLayersCount (the number of
   // fully connected layers below); confirm this is intended and not a neuron count.
   desc.count = m_hiddenLayersCount;
   desc.type = defNeuronLSTM;
   desc.activation = TANH;
   desc.optimization = (m_optimizationAlgo == 0) ? SGD : ADAM;
   desc.window = (int)m_historyBars * m_neuronsCount;
   desc.step = (int)m_historyBars / 2;
   if(!Topology.Add(desc))
     {
      delete desc;
      delete Topology;
      return false;
     }
//--- Hidden Layers: fully connected, shrinking by m_neuronsReduction percent
   int n = m_initialNeuronsCount;
   for(int i = 0; i < m_hiddenLayersCount; i++)
     {
      desc = new CLayerDescription();
      if(CheckPointer(desc) == POINTER_INVALID)
        {
         delete Topology;
         return false;
        }
      desc.count = n;
      desc.type = defNeuron;
      desc.activation = TANH;
      desc.optimization = (m_optimizationAlgo == 0) ? SGD : ADAM;
      if(!Topology.Add(desc))
        {
         delete desc;
         delete Topology;
         return false;
        }
      n = (int)MathMax(n * (m_neuronsReduction * 0.01), m_minNeuronsCount);
     }
//--- Output Layer
   desc = new CLayerDescription();
   if(CheckPointer(desc) == POINTER_INVALID)
     {
      delete Topology;
      return false;
     }
   desc.count = m_outputNeuronsCount;
   desc.type = defNeuron;
   desc.activation = TANH;
   desc.optimization = (m_optimizationAlgo == 0) ? SGD : ADAM;
   if(!Topology.Add(desc))
     {
      delete desc;
      delete Topology;
      return false;
     }
//--- replace the placeholder network with one built from the topology
   if(CheckPointer(Net) != POINTER_INVALID)
      delete Net;
   Net = new CNet(Topology);
   delete Topology;
   return (CheckPointer(Net) != POINTER_INVALID);
  }
//+------------------------------------------------------------------+
//| Time (as long) from which the next training pass should start:   |
//| the open time of the bar 100 * recentAverageSmoothingFactor bars |
//| back (ten times further while dForecast < m_stopTrainWR), but no |
//| later than dtStudied and never negative.                         |
//| Requires Net to be a valid pointer.                              |
//+------------------------------------------------------------------+
long CSignalLSTM::StudyStartTimestamp(const string symbol)
  {
   int shift = (int)(100 * Net.recentAverageSmoothingFactor * (dForecast >= m_stopTrainWR ? 1 : 10));
   return (long)MathMax(0, MathMin(iTime(symbol, PERIOD_CURRENT, shift), dtStudied));
  }
//+------------------------------------------------------------------+
//| OnTick function                                                  |
//| Posts a new "study" event when none is pending and either no     |
//| signal has been produced yet (dPrevSignal == -2) or a newer bar  |
//| than dtStudied exists; then shows the training status comment.   |
//+------------------------------------------------------------------+
void CSignalLSTM::OnTickHandler(void)
  {
   if(CheckPointer(Net) == POINTER_INVALID)
      return; // InitIndicators has not (successfully) created the network yet
   if(!bEventStudy && (dPrevSignal == -2 || dtStudied < SeriesInfoInteger(m_symbol.Name(), m_period, SERIES_LASTBAR_DATE)))
      bEventStudy = EventChartCustom(ChartID(), 1, StudyStartTimestamp(m_symbol.Name()), 0, "New Bar");
   Comment(StringFormat("Event call %s; PrevSignal %.5f; Model trained %s -> %s", (string)bEventStudy, dPrevSignal, TimeToString(dtStudied), TimeToString(SeriesInfoInteger(m_symbol.Name(), PERIOD_CURRENT, SERIES_LASTBAR_DATE))));
  }
//+------------------------------------------------------------------+
//| Chart event handler                                              |
//| Handles the custom "study" event (id 1 sent via EventChartCustom,|
//| received as CHARTEVENT_CUSTOM + 1): trains from the time passed  |
//| in lparam, clears the pending flag and re-checks for new bars.   |
//+------------------------------------------------------------------+
void CSignalLSTM::OnChartEventHandler(const int id, const long &lparam, const double &dparam, const string &sparam)
  {
   /*
   if(id == CHARTEVENT_OBJECT_CLICK && lparam == m_buttonHandle)
     {
      m_isPlaying = !m_isPlaying;
      if(m_isPlaying)
        {
         ObjectSetString(0, "PlayPauseButton", OBJPROP_TEXT, "Pause");
        }
      else
        {
         ObjectSetString(0, "PlayPauseButton", OBJPROP_TEXT, "Play");
        }
     }
   */
   if(id == CHARTEVENT_CUSTOM + 1)
     {
      Train(lparam);
      bEventStudy = false;
      OnTickHandler();
     }
  }
//+------------------------------------------------------------------+
//| "Voting" that price will grow.                                   |
//| Returns m_pattern_0 when the last network output decodes to Buy, |
//| otherwise 0.                                                     |
//+------------------------------------------------------------------+
int CSignalLSTM::LongCondition(void)
  {
   int result = 0;
   if(DoubleToSignal(dPrevSignal) == Buy)
     {
      result = m_pattern_0;
      m_active_pattern = "Pattern_0";
      m_active_direction = "Buy";
     }
//--- return the result
   return(result);
  }
//+------------------------------------------------------------------+
//| "Voting" that price will fall.                                   |
//+------------------------------------------------------------------+
//| Votes m_pattern_0 for a short entry when the most recent network |
//| output (dPrevSignal) decodes to Sell, recording Pattern_0/Sell   |
//| as the active pattern; otherwise casts no vote (0).              |
//+------------------------------------------------------------------+
int CSignalLSTM::ShortCondition(void)
  {
   if(DoubleToSignal(dPrevSignal) != Sell)
      return 0;
   m_active_pattern = "Pattern_0";
   m_active_direction = "Sell";
   return m_pattern_0;
  }
//+------------------------------------------------------------------+
//| Set the specified pattern's weight to the specified value.       |
//| Only pattern 0 exists; any other pattern number is ignored.      |
//+------------------------------------------------------------------+
void CSignalLSTM::ApplyPatternWeight(int patternNumber, int weight)
  {
   if(patternNumber == 0)
      Pattern_0(weight);
  }
//+------------------------------------------------------------------+