Scikit.Classification.ONNX/Scikit.Classification.ONNX.mq5

421 lines
28 KiB
MQL5
Raw Permalink Normal View History

2025-05-30 16:23:21 +02:00
//+------------------------------------------------------------------+
//| Scikit.Classification.ONNX.mq5 |
//| Copyright 2023, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, MetaQuotes Ltd."
#property link "https://www.mql5.com"
#property version "1.00"
#include "iris.mqh"
//--- ONNX model files declared as resources, one per classifier. Each one
//--- is exposed as a const uchar array; CreateModel() passes these arrays
//--- to OnnxCreateFromBuffer().
#resource "models\\adaboost_iris.onnx" as const uchar ExtModelAdaBoost[];
#resource "models\\bnb_classifier_iris.onnx" as const uchar ExtModelBNB[];
#resource "models\\bagging_iris.onnx" as const uchar ExtModelBagging[];
#resource "models\\categorical_nb_iris.onnx" as const uchar ExtModelCategorialNB[];
#resource "models\\cnb_classifier_iris.onnx" as const uchar ExtModelCNB[];
#resource "models\\decision_tree_iris.onnx" as const uchar ExtModelDecisionTree[];
#resource "models\\extra_tree_iris.onnx" as const uchar ExtModelExtraTree[];
#resource "models\\extra_trees_iris.onnx" as const uchar ExtModelExtraTrees[];
#resource "models\\gnb_classifier_iris.onnx" as const uchar ExtModelGNB[];
#resource "models\\gb_iris.onnx" as const uchar ExtModelGB[];
#resource "models\\hist_gradient_boosting_classifier_iris.onnx" as const uchar ExtModelHistGB[];
#resource "models\\knn_iris.onnx" as const uchar ExtModelKNN[];
#resource "models\\lda_classifier_iris.onnx" as const uchar ExtModelLDA[];
#resource "models\\linear_svc_iris.onnx" as const uchar ExtModelLinearSVC[];
#resource "models\\logistic_regression_iris.onnx" as const uchar ExtModelLogisticRegression[];
#resource "models\\logistic_regressioncv_iris.onnx" as const uchar ExtModelLogisticRegressionCV[];
#resource "models\\mlp_classifier_iris.onnx" as const uchar ExtModelMLP[];
#resource "models\\mnb_classifier_iris.onnx" as const uchar ExtModelMNB[];
#resource "models\\nusvc_iris.onnx" as const uchar ExtModelNuSVC[];
#resource "models\\pa_classifier_iris.onnx" as const uchar ExtModelPA[];
#resource "models\\perceptron_classifier_iris.onnx" as const uchar ExtModelPerceptron[];
#resource "models\\radius_neighbors_iris.onnx" as const uchar ExtModelRN[];
#resource "models\\rf_iris.onnx" as const uchar ExtModelRF[];
#resource "models\\ridge_classifier_iris.onnx" as const uchar ExtModelRidge[];
#resource "models\\ridge_classifier_cv_iris.onnx" as const uchar ExtModelRidgeCV[];
#resource "models\\sgd_classifier_iris.onnx" as const uchar ExtModelSGD[];
#resource "models\\svc_iris.onnx" as const uchar ExtModelSVC[];
//--- Numeric identifiers of the classifiers. CreateModel() maps an id to
//--- its embedded model buffer and display name. OnStart() walks the whole
//--- id range in a single loop, so the values must stay contiguous from
//--- the first id to the last one.
#define ModelAdaBoost 1
#define ModelBNB 2
#define ModelBagging 3
#define ModelCategorialNB 4
#define ModelCNB 5
#define ModelDecisionTree 6
#define ModelExtraTree 7
#define ModelExtraTrees 8
#define ModelGNB 9
#define ModelGB 10
#define ModelHistGB 11
#define ModelKNN 12
#define ModelLDA 13
#define ModelLinearSVC 14
#define ModelLogisticRegression 15
#define ModelLogisticRegressionCV 16
#define ModelMLP 17
#define ModelMNB 18
#define ModelNuSVC 19
#define ModelPA 20
#define ModelPerceptron 21
#define ModelRN 22
#define ModelRF 23
#define ModelRidge 24
#define ModelRidgeCV 25
#define ModelSGD 26
#define ModelSVC 27
//+------------------------------------------------------------------+
//| CreateModel                                                      |
//+------------------------------------------------------------------+
//| Creates an ONNX session for the classifier selected by model_id. |
//|                                                                  |
//| model_id   - one of the Model* identifiers                       |
//| model_name - [out] human-readable classifier name; left          |
//|              untouched when model_id is not recognised           |
//|                                                                  |
//| Returns the value of OnnxCreateFromBuffer() for the matching     |
//| embedded model, or INVALID_HANDLE for an unknown model_id.       |
//| The caller releases the handle with OnnxRelease().                |
//+------------------------------------------------------------------+
long CreateModel(const int model_id,string &model_name)
  {
   long  model_handle=INVALID_HANDLE;
//--- creation flags; ONNX_DEBUG_LOGS may be substituted while debugging
   ulong flags=ONNX_DEFAULT;
//--- pick the display name and the embedded buffer for the requested id
   switch(model_id)
     {
      case ModelAdaBoost:
         model_name="Adaptive Boosting";
         model_handle=OnnxCreateFromBuffer(ExtModelAdaBoost,flags);
         break;
      case ModelBNB:
         model_name="Bernoulli Naive Bayes";
         model_handle=OnnxCreateFromBuffer(ExtModelBNB,flags);
         break;
      case ModelBagging:
         model_name="Bootstrap Aggregating";
         model_handle=OnnxCreateFromBuffer(ExtModelBagging,flags);
         break;
      case ModelCategorialNB:
         //--- the name previously started with a mis-encoded character
         model_name="Categorical Naive Bayes";
         model_handle=OnnxCreateFromBuffer(ExtModelCategorialNB,flags);
         break;
      case ModelCNB:
         model_name="Complement Naive Bayes";
         model_handle=OnnxCreateFromBuffer(ExtModelCNB,flags);
         break;
      case ModelDecisionTree:
         model_name="Decision Tree";
         model_handle=OnnxCreateFromBuffer(ExtModelDecisionTree,flags);
         break;
      case ModelExtraTree:
         model_name="Extra Tree";
         model_handle=OnnxCreateFromBuffer(ExtModelExtraTree,flags);
         break;
      case ModelExtraTrees:
         model_name="Extra Trees";
         model_handle=OnnxCreateFromBuffer(ExtModelExtraTrees,flags);
         break;
      case ModelGNB:
         model_name="Gaussian Naive Bayes";
         model_handle=OnnxCreateFromBuffer(ExtModelGNB,flags);
         break;
      case ModelGB:
         model_name="Gradient Boosting";
         model_handle=OnnxCreateFromBuffer(ExtModelGB,flags);
         break;
      case ModelHistGB:
         model_name="Histogram-Based Gradient Boosting";
         model_handle=OnnxCreateFromBuffer(ExtModelHistGB,flags);
         break;
      case ModelKNN:
         model_name="K-NN";
         model_handle=OnnxCreateFromBuffer(ExtModelKNN,flags);
         break;
      case ModelLDA:
         model_name="Linear Discriminant Analysis";
         model_handle=OnnxCreateFromBuffer(ExtModelLDA,flags);
         break;
      case ModelLinearSVC:
         model_name="LinearSVC";
         model_handle=OnnxCreateFromBuffer(ExtModelLinearSVC,flags);
         break;
      case ModelLogisticRegression:
         model_name="Logistic Regression";
         model_handle=OnnxCreateFromBuffer(ExtModelLogisticRegression,flags);
         break;
      case ModelLogisticRegressionCV:
         model_name="Logistic RegressionCV";
         model_handle=OnnxCreateFromBuffer(ExtModelLogisticRegressionCV,flags);
         break;
      case ModelMLP:
         model_name="MLP Classifier";
         model_handle=OnnxCreateFromBuffer(ExtModelMLP,flags);
         break;
      case ModelMNB:
         model_name="Multinomial Naive Bayes";
         model_handle=OnnxCreateFromBuffer(ExtModelMNB,flags);
         break;
      case ModelNuSVC:
         model_name="NuSVC";
         model_handle=OnnxCreateFromBuffer(ExtModelNuSVC,flags);
         break;
      case ModelPA:
         model_name="Passive-Aggressive";
         model_handle=OnnxCreateFromBuffer(ExtModelPA,flags);
         break;
      case ModelPerceptron:
         model_name="Perceptron";
         model_handle=OnnxCreateFromBuffer(ExtModelPerceptron,flags);
         break;
      case ModelRN:
         model_name="Radius Neighbors";
         model_handle=OnnxCreateFromBuffer(ExtModelRN,flags);
         break;
      case ModelRF:
         model_name="Random Forest";
         model_handle=OnnxCreateFromBuffer(ExtModelRF,flags);
         break;
      case ModelRidge:
         model_name="Ridge";
         model_handle=OnnxCreateFromBuffer(ExtModelRidge,flags);
         break;
      case ModelRidgeCV:
         model_name="RidgeCV";
         model_handle=OnnxCreateFromBuffer(ExtModelRidgeCV,flags);
         break;
      case ModelSGD:
         model_name="SGD";
         model_handle=OnnxCreateFromBuffer(ExtModelSGD,flags);
         break;
      case ModelSVC:
         model_name="SVC";
         model_handle=OnnxCreateFromBuffer(ExtModelSVC,flags);
         break;
      default:
         //--- unknown id: keep INVALID_HANDLE and leave model_name as passed in
         break;
     }
//---
   return(model_handle);
  }
//+------------------------------------------------------------------+
//| TestSampleSequenceMapOutput                                      |
//+------------------------------------------------------------------+
//| Runs one iris sample through a model whose second output is a    |
//| sequence of key/value maps and reports the key that carries the  |
//| largest value.                                                   |
//|                                                                  |
//| model          - ONNX session handle                             |
//| iris_sample    - sample whose four features form the input row   |
//| model_class_id - [out] key with the maximum value, or -1 when    |
//|                  nothing could be determined                     |
//|                                                                  |
//| Returns true when the output buffer was allocated and OnnxRun()  |
//| succeeded, false otherwise.                                      |
//+------------------------------------------------------------------+
bool TestSampleSequenceMapOutput(long model,sIRISsample &iris_sample, int &model_class_id)
  {
   model_class_id=-1;
//--- one input row with the four features converted to float
   float input_data[1][4];
   for(int k=0; k<4; k++)
      input_data[0][k]=(float)iris_sample.features[k];
//--- buffer for the first model output; it is filled by OnnxRun but not read here
   float out1[];
//--- buffer for the second model output: one key/value map per input row
   struct Map
     {
      ulong          key[];
      float          value[];
     } out2[];
//--- the first output needs one element per input row
   if(ArrayResize(out1,input_data.Range(0))!=input_data.Range(0))
      return(false);
//--- describe input and first output shapes to the runtime
   ulong input_shape[]= { input_data.Range(0), input_data.Range(1) };
   ulong output_shape[]= { input_data.Range(0) };
   OnnxSetInputShape(model,0,input_shape);
   OnnxSetOutputShape(model,0,output_shape);
//--- run the model
   if(!OnnxRun(model,0,input_data,out1,out2))
      return(false);
//--- postprocessing of sequence map data: find the key with the maximum value.
//--- The map arrays are read in place; copying them into reusable scratch
//--- arrays would keep stale trailing entries, because ArrayCopy never
//--- shrinks its destination.
   for(uint n=0; n<out2.Size(); n++)
     {
      int keys_total=ArraySize(out2[n].key);
      int values_total=ArraySize(out2[n].value);
      //--- never index past the shorter of the two arrays
      int total=(keys_total<values_total) ? keys_total : values_total;
      if(total<=0)
         continue;
      //--- strict '>' keeps the first entry among equal maxima
      int best=0;
      for(int k=1; k<total; k++)
        {
         if(out2[n].value[k]>out2[n].value[best])
            best=k;
        }
      model_class_id=(int)out2[n].key[best];
     }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| TestSampleTensorOutput                                           |
//+------------------------------------------------------------------+
//| Runs one iris sample through a model whose two outputs are plain |
//| tensors and reports the value of the first output as the class.  |
//|                                                                  |
//| model          - ONNX session handle                             |
//| iris_sample    - sample whose four features form the input row   |
//| model_class_id - [out] first element of the first output, or -1  |
//|                  when OnnxRun() fails                            |
//|                                                                  |
//| Returns the result of OnnxRun().                                 |
//+------------------------------------------------------------------+
bool TestSampleTensorOutput(long model,sIRISsample &iris_sample, int &model_class_id)
  {
   model_class_id=-1;
//--- pack the four features into a single float row
   float row[1][4];
   for(int f=0; f<4; f++)
      row[0][f]=(float)iris_sample.features[f];
//--- output buffers: first output is {1}, second output is {1,3}
   int   first_output[1];
   float second_output[1,3];
//--- input shape: one row, four columns
   ulong shape_in[]= { 1, 4};
   OnnxSetInputShape(model,0,shape_in);
//--- shape of the first output
   ulong shape_out0[]= {1};
   OnnxSetOutputShape(model,0,shape_out0);
//--- shape of the second output
   ulong shape_out1[]= {1,3};
   OnnxSetOutputShape(model,1,shape_out1);
//--- run the model
   if(!OnnxRun(model,0,row,first_output,second_output))
      return(false);
//--- for these models the class is delivered in the first output
   model_class_id=first_output[0];
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
//| Loads the iris dataset, then for every model id creates the ONNX |
//| session, classifies each sample and prints the share of samples  |
//| whose predicted class equals the dataset class.                  |
//|                                                                  |
//| Returns 0 on completion and -1 when the dataset is empty.        |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   sIRISsample iris_samples[];
//--- load dataset from file
   PrepareIrisDataset(iris_samples);
   int total_samples=ArraySize(iris_samples);
   if(total_samples==0)
     {
      Print("error in loading iris dataset from iris.csv");
      //--- non-zero code: 'false' would be indistinguishable from success
      return(-1);
     }
//--- test all iris dataset sample by sample, model by model
   for(int i=ModelAdaBoost; i<=ModelSVC; i++)
     {
      string model_name="";
      long   model=CreateModel(i,model_name);
      if(model==INVALID_HANDLE)
        {
         PrintFormat("OnnxCreate error %d for model %d (%s)",GetLastError(),i,model_name);
         continue;
        }
      //--- these models are read with the tensor processor; all the others
      //--- are read with the sequence-map processor
      bool tensor_output=(i==ModelLinearSVC || i==ModelNuSVC || i==ModelSVC ||
                          i==ModelRN || i==ModelRidge || i==ModelRidgeCV);
      //--- check all samples
      int correct_results=0;
      int failed_runs=0;
      for(int k=0; k<total_samples; k++)
        {
         int  model_class_id=-1;
         bool run_ok=false;
         //--- select data output processor
         if(tensor_output)
            run_ok=TestSampleTensorOutput(model,iris_samples[k],model_class_id);
         else
            run_ok=TestSampleSequenceMapOutput(model,iris_samples[k],model_class_id);
         //--- a failed run is still counted against accuracy, but it is
         //--- reported separately instead of looking like a misclassification
         if(!run_ok)
           {
            failed_runs++;
            continue;
           }
         if(model_class_id==iris_samples[k].class_id)
            correct_results++;
        }
      PrintFormat("%d model:%s accuracy: %.4f",i,model_name,1.0*correct_results/total_samples);
      if(failed_runs>0)
         PrintFormat("%d model:%s inference failed for %d of %d samples",i,model_name,failed_runs,total_samples);
      //--- release model
      OnnxRelease(model);
     }
//---
   return(0);
  }