ONNX.Price.Prediction/ONNX.Price.Prediction.Test3.mq5

176 lines
14 KiB
MQL5
Raw Permalink Normal View History

2024-08-08 23:31:54 +02:00
//+------------------------------------------------------------------+
//| ONNX.Price.Prediction.Test3.mq5 |
//| Copyright 2023, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
#property description "Evaluation of the quality of the next bar high/low change prediction by Python\\model3.onnx.\n"
                      "Start in strategy tester on EURUSD,H1, open prices, from 2023.01.01 to 2023.02.01"
//--- the ONNX model is embedded into the compiled EX5 as a byte array
#resource "Python/model3.onnx" as uchar ExtModel[]
//--- bars count in one model sample; must match the training script's history size
//--- NOTE(review): originally taken from Python\PricePredictionTraining.py
//--- (X, y = collect_dataset(df, history_size=10)) - confirm it also holds for model3
#define SAMPLE_SIZE 10
long     ExtHandle=INVALID_HANDLE;   // ONNX session handle, released in OnDeinit
double   ExtPredicted=0;             // last model output; 0.0 means "no prediction available"
datetime ExtNextBar=0;               // open time of the next bar; ticks before it are ignored
long     ExtTests=0;                 // number of scored predictions
long     ExtRightDirection=0;        // scored predictions whose sign matched the actual change
double   ExtSumAbsoluteError=0.0;    // sum of |predicted - actual| over scored predictions
//+------------------------------------------------------------------+
//| Expert initialization function |
//+------------------------------------------------------------------+
int OnInit()
  {
//--- the embedded model may only be run on the chart it was built for
   bool wrong_chart=(Symbol()!="EURUSD" || Period()!=PERIOD_H1);
   if(wrong_chart)
     {
      Print("model must work with EURUSD,H1");
      return(INIT_FAILED);
     }
//--- build the ONNX session from the #resource byte array
   ExtHandle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);
   if(ExtHandle==INVALID_HANDLE)
     {
      int create_err=GetLastError();
      Print("OnnxCreateFromBuffer error ",create_err);
      return(INIT_FAILED);
     }
//--- the model leaves some tensor dimensions open, so fix them here:
//--- input  = {batch, bars in sample, OHLC series}
   const long in_shape[]={1,SAMPLE_SIZE,4};
   if(!OnnxSetInputShape(ExtHandle,0,in_shape))
     {
      int in_err=GetLastError();
      Print("OnnxSetInputShape error ",in_err);
      return(INIT_FAILED);
     }
//--- output = {batch (same as input), one predicted value}
   const long out_shape[]={1,1};
   if(!OnnxSetOutputShape(ExtHandle,0,out_shape))
     {
      int out_err=GetLastError();
      Print("OnnxSetOutputShape error ",out_err);
      return(INIT_FAILED);
     }
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
//--- nothing to free if the session was never created or is already released
   if(ExtHandle==INVALID_HANDLE)
      return;
//--- release the ONNX session and mark the handle as unused
   OnnxRelease(ExtHandle);
   ExtHandle=INVALID_HANDLE;
  }
//+------------------------------------------------------------------+
//| Expert tick function |
//+------------------------------------------------------------------+
void OnTick()
  {
//--- act once per bar: ignore ticks until the next bar has opened
   if(TimeCurrent()<ExtNextBar)
      return;
//--- next bar time = start of the current bar + one period
   ExtNextBar=TimeCurrent();
   ExtNextBar-=ExtNextBar%PeriodSeconds();
   ExtNextBar+=PeriodSeconds();
//--- score the prediction made at the previous bar
//--- (ExtTests is counted there, only for predictions actually scored,
//---  so the first bar and failed data copies do not skew the statistics)
   CheckPredicted();
//--- predict for the bar that has just opened
   PredictPrice();
  }
//+------------------------------------------------------------------+
//| Check predicted price                                            |
//| Compares ExtPredicted with the actual change of the last closed  |
//| bar: (high[last]-high[prev]) + (low[last]-low[prev]).            |
//| Updates ExtTests, ExtSumAbsoluteError and ExtRightDirection.     |
//| Does nothing when there is no prediction (ExtPredicted==0.0) or  |
//| when the bar data cannot be copied.                              |
//+------------------------------------------------------------------+
void CheckPredicted(void)
  {
   if(ExtPredicted==0.0)
      return;
   static double   highs[3];
   static double   lows[3];
   static datetime times[3];
   double hhlls[2]={0.0};
//--- last three closed bars, oldest first (index 2 = the bar just closed)
   if(CopyHigh(Symbol(),Period(),1,3,highs)!=3 ||
      CopyLow(Symbol(),Period(),1,3,lows)!=3 ||
      CopyTime(Symbol(),Period(),1,3,times)!=3)
      return;
//--- count only predictions that are actually scored
   ExtTests++;
//--- hhlls[0] is logged for context only; hhlls[1] is the value being predicted
   hhlls[0]=(highs[1]-highs[0])+(lows[1]-lows[0]);
   hhlls[1]=(highs[2]-highs[1])+(lows[2]-lows[1]);
   ExtSumAbsoluteError+=MathAbs(ExtPredicted-hhlls[1]);
//--- compare signs (-1, 0, +1) of predicted and actual changes
   double delta_predict=ExtPredicted > 0.0 ? 1 : ExtPredicted < 0.0 ? -1 : 0.0;
   double delta_actual=hhlls[1] > 0.0 ? 1 : hhlls[1] < 0.0 ? -1 : 0.0;
   if(delta_predict==delta_actual)
      ExtRightDirection++;
   Print("----------------------------------");
   Print("times[0]: ", times[0], "; highs[0]: ", highs[0], "; lows[0]: ", lows[0], "; times[1]: ", times[1], "; highs[1]: ", highs[1], "; lows[1]: ", lows[1], "; hhlls[0]: ", hhlls[0], "; times[2]: ", times[2], "; highs[2]: ", highs[2], "; lows[2]: ", lows[2], "; hhlls[1]: ", hhlls[1]);
   Print("ExtPredicted: ", ExtPredicted, "; ExtSumAbsoluteError (ExtPredicted-hhlls[1]): ", ExtSumAbsoluteError, "; delta_predict: ", delta_predict, "; delta_actual: ", delta_actual, "; ExtRightDirection: ", ExtRightDirection, "; ExtTests: ", ExtTests);
  }
//+------------------------------------------------------------------+
//| Predict next price |
//+------------------------------------------------------------------+
void PredictPrice(void)
  {
   static matrixf input_data(SAMPLE_SIZE,4); // prepared model input
   static vectorf output_data(1);            // model output
   static matrix  mm(SAMPLE_SIZE,4);         // per-series Mean repeated in every row
   static matrix  ms(SAMPLE_SIZE,4);         // per-series Std repeated in every row
   static matrix  x_norm(SAMPLE_SIZE,4);     // normalized prices
//--- 0.0 means "no prediction"; it is only overwritten after a successful run
   ExtPredicted=0.0;
//--- last SAMPLE_SIZE closed bars (shift 1), one row per OHLC component
   matrix rates;
   if(!rates.CopyRates("EURUSD",PERIOD_H1,COPY_RATES_OHLC,1,SAMPLE_SIZE))
      return;
//--- per-series Mean and Std over the window
   vector m=rates.Mean(1);
   vector s=rates.Std(1);
//--- a flat series has zero std; normalizing it would divide by zero
//--- and feed inf/NaN into the model, so skip this bar instead
   if(s.Min()==0.0)
      return;
//--- broadcast Mean/Std into matrices shaped like the transposed input
   for(int i=0; i<SAMPLE_SIZE; i++)
     {
      mm.Row(m,i);
      ms.Row(s,i);
     }
//--- the model expects {bars, OHLC}: vertical OHLC vectors
   x_norm=rates.Transpose();
//--- z-score normalization of every series
   x_norm-=mm;
   x_norm/=ms;
//--- run the inference
   input_data.Assign(x_norm);
   if(!OnnxRun(ExtHandle,ONNX_NO_CONVERSION,input_data,output_data))
      return;
//--- NOTE(review): the output is used as-is. CheckPredicted compares it with
//--- raw price differences; if the model was trained on normalized targets,
//--- it must be denormalized here for the MAE to be meaningful - confirm.
   ExtPredicted=output_data[0];
  }
//+------------------------------------------------------------------+
//| Tester function |
//+------------------------------------------------------------------+
double OnTester()
  {
//--- with no scored predictions both averages would divide by zero
   if(ExtTests<=0)
     {
      Print("no predictions were evaluated");
      return(0.0);
     }
//--- mean absolute error of the scored predictions
   double mae=ExtSumAbsoluteError/ExtTests;
   Print("mae = ",mae);
//--- percentage of predictions with the correct sign; used as the tester criterion
   double right_directions=(ExtRightDirection*100.0)/ExtTests;
   PrintFormat("right_directions = %.2f%%",right_directions);
   return(right_directions);
  }
//+------------------------------------------------------------------+