EasySbAi/Models/OnnxDef.mqh
Nique_372 c2869c89a8
2026-04-10 21:47:52 -05:00

266 líneas
19 KiB
MQL5

//+------------------------------------------------------------------+
//| OnnxDef.mqh |
//| Copyright 2025, Niquel Mendoza. |
//| https://www.mql5.com/es/users/nique_372 |
//+------------------------------------------------------------------+
#property copyright "Copyright 2025, Niquel Mendoza."
#property link "https://www.mql5.com/es/users/nique_372"
#property strict
#ifndef EASYSB_MODELS_ONNXDEF_MQH
#define EASYSB_MODELS_ONNXDEF_MQH
//+------------------------------------------------------------------+
//| Include |
//+------------------------------------------------------------------+
#include <TSN\\MQLArticles\\Utils\\Basic.mqh>
//+------------------------------------------------------------------+
//| Defines |
//+------------------------------------------------------------------+
// Number of entries COnnxSimpleModel::Predict reserves in the key[] and value[]
// arrays of its map output (one key/value pair per class).
#define SIMPLE_MODEL_ONNX_CLASES 2
//+------------------------------------------------------------------+
//| Classification model wrapper around an ONNX session handle.      |
//| The handle is created by one of the Init() overloads and is      |
//| released when the object is destroyed.                           |
//+------------------------------------------------------------------+
class COnnxSimpleModel : public CLoggerBase
  {
private:
   long              m_onnx_handle;   // ONNX session handle, INVALID_HANDLE while no model is loaded
   //--- Layout used to receive the second model output in Predict():
   //--- parallel arrays of class keys and their float scores
   struct Map
     {
      long              key[];
      float             value[];
     };
public:
                     COnnxSimpleModel(void) : m_onnx_handle(INVALID_HANDLE) {}
   //--- Release the session so the handle does not outlive the object
                    ~COnnxSimpleModel(void)
     {
      if(m_onnx_handle != INVALID_HANDLE)
        {
         OnnxRelease(m_onnx_handle);
         m_onnx_handle = INVALID_HANDLE;
        }
     }
   //--- Load the model from a memory buffer; input_num is the second dimension of the {1, input_num} input shape
   bool              Init(const uchar& data[], int input_num);
   //--- Load the model from a file; comon=true passes ONNX_COMMON_FOLDER to OnnxCreate
   bool              Init(const string& file_name, bool comon, int input_num);
   //--- Run the model on one sample and report the per-class values and the winning class
   bool              Predict(vector &features, vector& probas, ulong& best_class_index, long &out_best_class, long onnx_exe_flags);
  };
//+------------------------------------------------------------------+
//| Creates the ONNX session from an in-memory model buffer.         |
//|   data      - raw bytes of the ONNX model                        |
//|   input_num - second dimension of the input shape {1, input_num} |
//| Returns true when the session exists and the input shape was     |
//| accepted; false otherwise (m_onnx_handle is then INVALID_HANDLE, |
//| except when input_num is rejected before anything is touched).   |
//+------------------------------------------------------------------+
bool COnnxSimpleModel::Init(const uchar &data[], int input_num)
  {
//--- A non-positive count would wrap around when stored in the ulong shape array
   if(input_num <= 0)
     {
      LogError(StringFormat("Numero de entradas invalido para el modelo onnx [f=%d]", input_num), FUNCION_ACTUAL);
      return false;
     }
//--- Drop a session left by a previous Init() so its handle is not leaked
   if(m_onnx_handle != INVALID_HANDLE)
     {
      OnnxRelease(m_onnx_handle);
      m_onnx_handle = INVALID_HANDLE;
     }
//--- Create the session
   ResetLastError();
   m_onnx_handle = OnnxCreateFromBuffer(data, 0);
   if(m_onnx_handle == INVALID_HANDLE)
     {
      LogError(StringFormat("Fallo al crear el modelo onnx por buffer, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
      return false;
     }
//--- Input 0: one sample of input_num values
   ulong input_shape[] = {1, input_num};
   ResetLastError();
   if(!OnnxSetInputShape(m_onnx_handle, 0, input_shape))
     {
      // Log first: OnnxRelease below may overwrite the last error code
      LogError(StringFormat("Fallo al setear el input shape el modelo onnx por buffer [f=%d], ultimo error = %d", input_num, GetLastError()), FUNCION_ACTUAL);
      // Do not keep a half-configured session around
      OnnxRelease(m_onnx_handle);
      m_onnx_handle = INVALID_HANDLE;
      return false;
     }
//--- Output 0 (label): a single value for the single sample.
//--- A failure is reported but kept non-fatal, as the result was previously ignored.
   ulong ouput_shape_label[] = {1};
   ResetLastError();
   if(!OnnxSetOutputShape(m_onnx_handle, 0, ouput_shape_label))
      LogError(StringFormat("Fallo al setear el output shape [idx=0] del modelo onnx, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
//--- Output 1 (map): declared as {1, 2}; the original author notes the real shape is unclear.
//--- Reported but non-fatal for the same reason as above.
   ulong ouput_shape_map[] = {1, 2};
   ResetLastError();
   if(!OnnxSetOutputShape(m_onnx_handle, 1, ouput_shape_map))
      LogError(StringFormat("Fallo al setear el output shape [idx=1] del modelo onnx, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Creates the ONNX session from a model file.                      |
//|   file_name - path handed to OnnxCreate                          |
//|   comon     - true passes ONNX_COMMON_FOLDER, false passes 0     |
//|   input_num - second dimension of the input shape {1, input_num} |
//| Returns true when the session exists and the input shape was     |
//| accepted; false otherwise (m_onnx_handle is then INVALID_HANDLE, |
//| except when input_num is rejected before anything is touched).   |
//+------------------------------------------------------------------+
bool COnnxSimpleModel::Init(const string &file_name, bool comon, int input_num)
  {
//--- A non-positive count would wrap around when stored in the ulong shape array
   if(input_num <= 0)
     {
      LogError(StringFormat("Numero de entradas invalido para el modelo onnx [f=%d]", input_num), FUNCION_ACTUAL);
      return false;
     }
//--- Drop a session left by a previous Init() so its handle is not leaked
   if(m_onnx_handle != INVALID_HANDLE)
     {
      OnnxRelease(m_onnx_handle);
      m_onnx_handle = INVALID_HANDLE;
     }
//--- Create the session
   ResetLastError();
   m_onnx_handle = OnnxCreate(file_name, (comon ? ONNX_COMMON_FOLDER : 0));
   if(m_onnx_handle == INVALID_HANDLE)
     {
      LogError(StringFormat("Fallo al crear el modelo onnx por file = %s, ultimo error = %d", file_name, GetLastError()), FUNCION_ACTUAL);
      return false;
     }
//--- Input 0: one sample of input_num values
   ulong input_shape[] = {1, input_num};
   ResetLastError();
   if(!OnnxSetInputShape(m_onnx_handle, 0, input_shape))
     {
      // Message corrected: this overload loads from a file, not from a buffer.
      // Log first: OnnxRelease below may overwrite the last error code.
      LogError(StringFormat("Fallo al setear el input shape el modelo onnx por file [f=%d], ultimo error = %d", input_num, GetLastError()), FUNCION_ACTUAL);
      // Do not keep a half-configured session around
      OnnxRelease(m_onnx_handle);
      m_onnx_handle = INVALID_HANDLE;
      return false;
     }
//--- Output 0 (label): a single value for the single sample.
//--- A failure is reported but kept non-fatal, as the result was previously ignored.
   ulong ouput_shape_label[] = {1};
   ResetLastError();
   if(!OnnxSetOutputShape(m_onnx_handle, 0, ouput_shape_label))
      LogError(StringFormat("Fallo al setear el output shape [idx=0] del modelo onnx, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
//--- Output 1 (map): declared as {1, 2}; the original author notes the real shape is unclear.
//--- Reported but non-fatal for the same reason as above.
   ulong ouput_shape_map[] = {1, 2};
   ResetLastError();
   if(!OnnxSetOutputShape(m_onnx_handle, 1, ouput_shape_map))
      LogError(StringFormat("Fallo al setear el output shape [idx=1] del modelo onnx, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Runs the classifier on one feature vector.                       |
//|   features         - input values (converted to float)           |
//|   probas           - receives the value[] array of the map output|
//|   best_class_index - receives ArgMax of probas                   |
//|   out_best_class   - receives the key stored at that index       |
//|   onnx_exe_flags   - flags forwarded to OnnxRun                  |
//| On failure returns false with probas emptied and both outputs    |
//| set to -1 (best_class_index is unsigned, so -1 wraps around).    |
//+------------------------------------------------------------------+
bool COnnxSimpleModel::Predict(vector &features, vector& probas, ulong& best_class_index, long &out_best_class, long onnx_exe_flags)
  {
//--- Float copy of the input; static so the buffer is reused between calls
   static vectorf input_f;
   input_f.Assign(features);
//--- Output buffers: one label, and one map with a key/value slot per class
   long label_out[];
   Map  map_out[];
   ArrayResize(label_out, 1);
   ArrayResize(map_out, 1);
   ArrayResize(map_out[0].key, SIMPLE_MODEL_ONNX_CLASES);
   ArrayResize(map_out[0].value, SIMPLE_MODEL_ONNX_CLASES);
//--- Execute the model
   ResetLastError();
   const bool ran_ok = OnnxRun(m_onnx_handle, onnx_exe_flags, input_f, label_out, map_out);
   if(ran_ok)
     {
      // The winning class is the key paired with the largest value
      probas.Assign(map_out[0].value);
      best_class_index = probas.ArgMax();
      out_best_class = map_out[0].key[best_class_index];
      return true;
     }
//--- Failure: report it and hand back sentinel outputs
   LogFatalError(StringFormat("No se pudo ejecutar el modelo onnx, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
   out_best_class = -1;
   probas.Resize(0);
   best_class_index = -1;
   return false;
  }
//+------------------------------------------------------------------+
//| Regression model wrapper around an ONNX session handle.          |
//+------------------------------------------------------------------+
class COnnxModelRegresoin : public CLoggerBase
  {
private:
   long              m_onnx_handle;   // ONNX session handle, INVALID_HANDLE while no model is loaded
public:
                     COnnxModelRegresoin(void) : m_onnx_handle(INVALID_HANDLE) {}
   //--- Releases the session if one was created
                    ~COnnxModelRegresoin(void)
     {
      if(m_onnx_handle == INVALID_HANDLE)
         return;
      OnnxRelease(m_onnx_handle);
     }
   //--- Load the model from a memory buffer; input_num is the second dimension of the {1, input_num} input shape
   bool              Init(const uchar& data[], int input_num);
   //--- Load the model from a file; comon=true passes ONNX_COMMON_FOLDER to OnnxCreate
   bool              Init(const string& file_name, bool comon, int input_num);
   //--- Run the model on one sample and write the single predicted value to output
   bool              Predict(vector& entrada, long onnx_exe_flags, double& output);
  };
//+------------------------------------------------------------------+
//| Creates the ONNX session from an in-memory model buffer.         |
//|   data      - raw bytes of the ONNX model                        |
//|   input_num - second dimension of the input shape {1, input_num} |
//| Returns true when the session exists and the input shape was     |
//| accepted; false otherwise (m_onnx_handle is then INVALID_HANDLE, |
//| except when input_num is rejected before anything is touched).   |
//+------------------------------------------------------------------+
bool COnnxModelRegresoin::Init(const uchar &data[], int input_num)
  {
//--- A non-positive count would wrap around when stored in the ulong shape array
   if(input_num <= 0)
     {
      LogError(StringFormat("Numero de entradas invalido para el modelo onnx [f=%d]", input_num), FUNCION_ACTUAL);
      return false;
     }
//--- Drop a session left by a previous Init() so its handle is not leaked
   if(m_onnx_handle != INVALID_HANDLE)
     {
      OnnxRelease(m_onnx_handle);
      m_onnx_handle = INVALID_HANDLE;
     }
//--- Create the session
   ResetLastError();
   m_onnx_handle = OnnxCreateFromBuffer(data, 0);
   if(m_onnx_handle == INVALID_HANDLE)
     {
      LogError(StringFormat("Fallo al crear el modelo onnx por buffer, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
      return false;
     }
//--- Input 0: one sample of input_num values
   ulong input_shape[] = {1, input_num};
   ResetLastError();
   if(!OnnxSetInputShape(m_onnx_handle, 0, input_shape))
     {
      // Log first: OnnxRelease below may overwrite the last error code
      LogError(StringFormat("Fallo al setear el input shape el modelo onnx por buffer [f=%d], ultimo error = %d", input_num, GetLastError()), FUNCION_ACTUAL);
      // Do not keep a half-configured session around
      OnnxRelease(m_onnx_handle);
      m_onnx_handle = INVALID_HANDLE;
      return false;
     }
//--- Output 0: only one predicted value is wanted.
//--- A failure is reported but kept non-fatal, as the result was previously ignored.
   ulong ouput_shape_label[] = {1};
   ResetLastError();
   if(!OnnxSetOutputShape(m_onnx_handle, 0, ouput_shape_label))
      LogError(StringFormat("Fallo al setear el output shape [idx=0] del modelo onnx, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Creates the ONNX session from a model file.                      |
//|   file_name - path handed to OnnxCreate                          |
//|   comon     - true passes ONNX_COMMON_FOLDER, false passes 0     |
//|   input_num - second dimension of the input shape {1, input_num} |
//| Returns true when the session exists and the input shape was     |
//| accepted; false otherwise (m_onnx_handle is then INVALID_HANDLE, |
//| except when input_num is rejected before anything is touched).   |
//+------------------------------------------------------------------+
bool COnnxModelRegresoin::Init(const string &file_name,bool comon,int input_num)
  {
//--- A non-positive count would wrap around when stored in the ulong shape array
   if(input_num <= 0)
     {
      LogError(StringFormat("Numero de entradas invalido para el modelo onnx [f=%d]", input_num), FUNCION_ACTUAL);
      return false;
     }
//--- Drop a session left by a previous Init() so its handle is not leaked
   if(m_onnx_handle != INVALID_HANDLE)
     {
      OnnxRelease(m_onnx_handle);
      m_onnx_handle = INVALID_HANDLE;
     }
//--- Create the session
   ResetLastError();
   m_onnx_handle = OnnxCreate(file_name, (comon ? ONNX_COMMON_FOLDER : 0));
   if(m_onnx_handle == INVALID_HANDLE)
     {
      LogError(StringFormat("Fallo al crear el modelo onnx por file = %s, ultimo error = %d", file_name, GetLastError()), FUNCION_ACTUAL);
      return false;
     }
//--- Input 0: one sample of input_num values
   ulong input_shape[] = {1, input_num};
   ResetLastError();
   if(!OnnxSetInputShape(m_onnx_handle, 0, input_shape))
     {
      // Message corrected: this overload loads from a file, not from a buffer.
      // Log first: OnnxRelease below may overwrite the last error code.
      LogError(StringFormat("Fallo al setear el input shape el modelo onnx por file [f=%d], ultimo error = %d", input_num, GetLastError()), FUNCION_ACTUAL);
      // Do not keep a half-configured session around
      OnnxRelease(m_onnx_handle);
      m_onnx_handle = INVALID_HANDLE;
      return false;
     }
//--- Output 0: only one predicted value is wanted.
//--- A failure is reported but kept non-fatal, as the result was previously ignored.
   ulong ouput_shape_label[] = {1};
   ResetLastError();
   if(!OnnxSetOutputShape(m_onnx_handle, 0, ouput_shape_label))
      LogError(StringFormat("Fallo al setear el output shape [idx=0] del modelo onnx, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
//---
   return true;
  }
//+------------------------------------------------------------------+
//| Runs the regression model on one input vector.                   |
//|   entrada        - input values (converted to a float matrix)    |
//|   onnx_exe_flags - flags forwarded to OnnxRun                    |
//|   output         - receives element [0][0] of the model output   |
//| Returns false and leaves output untouched when OnnxRun fails.    |
//+------------------------------------------------------------------+
bool COnnxModelRegresoin::Predict(vector& entrada, long onnx_exe_flags, double& output)
  {
//--- Static buffers are reused between calls
   static matrixf input_f;
   static matrixf result_f;
//--- Float copy of the input and a 1x1 slot for the prediction
   input_f.Assign(entrada);
   result_f.Resize(1, 1);
//--- Execute the model
   ResetLastError();
   const bool ran_ok = OnnxRun(m_onnx_handle, onnx_exe_flags, input_f, result_f);
   if(ran_ok)
     {
      output = result_f[0][0];
      return true;
     }
//--- Failure: report it, output is not modified
   LogError(StringFormat("Fallo al eejeuctar la prediccion del modelo onnx, ultimo error = %d", GetLastError()), FUNCION_ACTUAL);
   return false;
  }
//+------------------------------------------------------------------+
#endif //EASYSB_MODELS_ONNXDEF_MQH
//+------------------------------------------------------------------+