Scikit.Classification.ONNX/Iris_RidgeClassifier.mq5
super.admin b7e9325a38 convert
2025-05-30 16:23:21 +02:00

222 lines
8.8 KiB
MQL5

//+------------------------------------------------------------------+
//| Iris_RidgeClassifier.mq5 |
//| Copyright 2023, MetaQuotes Ltd. |
//| https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, MetaQuotes Ltd."
#property link "https://www.mql5.com"
#property version "1.00"
#include "iris.mqh"
#resource "models\\ridge_classifier_iris.onnx" as const uchar ExtModel[];
//+------------------------------------------------------------------+
//| Test IRIS dataset samples                                        |
//| Runs the model once on a batch of samples (one row of 4 features |
//| per sample) and copies the predicted label of every row into     |
//| model_classes_id[] (resized to the batch size).                  |
//| Returns false for an empty batch, if an input/output shape       |
//| cannot be set, or if OnnxRun fails.                              |
//+------------------------------------------------------------------+
bool TestSamples(long model,float &input_data[][4], int &model_classes_id[])
  {
//--- check number of input samples
   ulong batch_size=input_data.Range(0);
   if(batch_size==0)
      return(false);
//--- prepare output array
   ArrayResize(model_classes_id,(int)batch_size);
//--- input tensor shape: [batch_size, features per sample]
   ulong input_shape[]= { batch_size, input_data.Range(1)};
   if(!OnnxSetInputShape(model,0,input_shape))
     {
      PrintFormat("OnnxSetInputShape error %d",GetLastError());
      return(false);
     }
//--- output 0: one predicted label per sample
   int output1[];
//--- output 1: 3 values per sample (one per Iris class); presumably class
//--- scores — it must be supplied to OnnxRun but is not used below
   float output2[][3];
//---
   ArrayResize(output1,(int)batch_size);
   ArrayResize(output2,(int)batch_size);
//---
   ulong output_shape[]= {batch_size};
   if(!OnnxSetOutputShape(model,0,output_shape))
     {
      PrintFormat("OnnxSetOutputShape error %d for output 0",GetLastError());
      return(false);
     }
//---
   ulong output_shape2[]= {batch_size,3};
   if(!OnnxSetOutputShape(model,1,output_shape2))
     {
      PrintFormat("OnnxSetOutputShape error %d for output 1",GetLastError());
      return(false);
     }
//--- run the model on the whole batch
   bool res=OnnxRun(model,ONNX_DEBUG_LOGS,input_data,output1,output2);
//--- classes are ready in output1[k]
   if(res)
     {
      for(int k=0; k<(int)batch_size; k++)
         model_classes_id[k]=output1[k];
     }
   else
      PrintFormat("OnnxRun error %d",GetLastError());
//---
   return(res);
  }
//+------------------------------------------------------------------+
//| Test all samples from IRIS dataset (150)                         |
//| Here we test all samples with batch=1, sample by sample.         |
//| model_accuracy receives the fraction (0..1) of samples whose     |
//| predicted class equals the true class; a sample whose model run  |
//| fails is reported and counted as incorrect.                      |
//| Returns false only if the dataset could not be prepared.         |
//+------------------------------------------------------------------+
bool TestAllIrisDataset(const long model,const string model_name,double &model_accuracy)
  {
   model_accuracy=0;
   sIRISsample iris_samples[];
//--- fill the dataset (PrepareIrisDataset comes from iris.mqh)
   PrepareIrisDataset(iris_samples);
//--- test
   int total_samples=ArraySize(iris_samples);
   if(total_samples==0)
     {
      Print("iris dataset not prepared");
      return(false);
     }
//--- array for output classes
   int model_output_classes_id[];
//--- check all Iris dataset samples
   int correct_results=0;
   for(int k=0; k<total_samples; k++)
     {
      //--- input array: a batch holding a single sample
      float iris_sample_input_data[1][4];
      //--- prepare input data from kth iris sample dataset
      for(int f=0; f<4; f++)
         iris_sample_input_data[0][f]=(float)iris_samples[k].features[f];
      //--- run model; a failed run is reported and counted as incorrect
      if(!TestSamples(model,iris_sample_input_data,model_output_classes_id))
        {
         PrintFormat("model:%s sample=%d model run FAILED",model_name,iris_samples[k].sample_id);
         continue;
        }
      //--- check result
      if(model_output_classes_id[0]==iris_samples[k].class_id)
         correct_results++;
      else
         PrintFormat("model:%s sample=%d FAILED [class=%d, true class=%d] features=(%.2f,%.2f,%.2f,%.2f)",model_name,iris_samples[k].sample_id,model_output_classes_id[0],iris_samples[k].class_id,iris_samples[k].features[0],iris_samples[k].features[1],iris_samples[k].features[2],iris_samples[k].features[3]);
     }
   model_accuracy=1.0*correct_results/total_samples;
//---
   PrintFormat("model:%s correct results: %.2f%%",model_name,100*model_accuracy);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| Here we test batch execution of the model                        |
//| Runs two batches (3 and 10 hand-picked Iris samples) and stores  |
//| the fraction (0..1) of correctly classified samples in           |
//| model_accuracy. Returns false if either model run fails.         |
//+------------------------------------------------------------------+
bool TestBatchExecution(const long model,const string model_name,double &model_accuracy)
  {
   model_accuracy=0;
//--- array for output classes
   int model_output_classes_id[];
   int correct_results=0;
   int total_results=0;
   bool res=false;
//--- run batch with 3 samples
   float input_data_batch3[3][4]=
     {
        {5.1f,3.5f,1.4f,0.2f}, // iris dataset sample id=1, Iris-setosa
        {6.3f,2.5f,4.9f,1.5f}, // iris dataset sample id=73, Iris-versicolor
        {6.3f,2.7f,4.9f,1.8f}  // iris dataset sample id=124, Iris-virginica
     };
   int correct_classes_batch3[3]= {0,1,2};
//--- run model
   res=TestSamples(model,input_data_batch3,model_output_classes_id);
   if(!res)
      return(false);
//--- check result
   for(int j=0; j<ArraySize(model_output_classes_id); j++)
     {
      if(model_output_classes_id[j]==correct_classes_batch3[j])
         correct_results++;
      else
         PrintFormat("model:%s FAILED [class=%d, true class=%d] features=(%.2f,%.2f,%.2f,%.2f)",model_name,model_output_classes_id[j],correct_classes_batch3[j],input_data_batch3[j][0],input_data_batch3[j][1],input_data_batch3[j][2],input_data_batch3[j][3]);
      total_results++;
     }
//--- run batch with 10 samples
   float input_data_batch10[10][4]=
     {
        {5.5f,3.5f,1.3f,0.2f}, // iris dataset sample id=37 (Iris-setosa)
        {4.9f,3.1f,1.5f,0.1f}, // iris dataset sample id=38 (Iris-setosa)
        {4.4f,3.0f,1.3f,0.2f}, // iris dataset sample id=39 (Iris-setosa)
        {5.0f,3.3f,1.4f,0.2f}, // iris dataset sample id=50 (Iris-setosa)
        {7.0f,3.2f,4.7f,1.4f}, // iris dataset sample id=51 (Iris-versicolor)
        {6.4f,3.2f,4.5f,1.5f}, // iris dataset sample id=52 (Iris-versicolor)
        {6.3f,3.3f,6.0f,2.5f}, // iris dataset sample id=101 (Iris-virginica)
        {5.8f,2.7f,5.1f,1.9f}, // iris dataset sample id=102 (Iris-virginica)
        {7.1f,3.0f,5.9f,2.1f}, // iris dataset sample id=103 (Iris-virginica)
        {6.3f,2.9f,5.6f,1.8f}  // iris dataset sample id=104 (Iris-virginica)
     };
//--- correct classes for all 10 samples in the batch
   int correct_classes_batch10[10]= {0,0,0,0,1,1,2,2,2,2};
//--- run model
   res=TestSamples(model,input_data_batch10,model_output_classes_id);
   if(!res)
      return(false);
//--- check result
   for(int j=0; j<ArraySize(model_output_classes_id); j++)
     {
      if(model_output_classes_id[j]==correct_classes_batch10[j])
         correct_results++;
      else
         PrintFormat("model:%s FAILED [class=%d, true class=%d] features=(%.2f,%.2f,%.2f,%.2f)",model_name,model_output_classes_id[j],correct_classes_batch10[j],input_data_batch10[j][0],input_data_batch10[j][1],input_data_batch10[j][2],input_data_batch10[j][3]);
      total_results++;
     }
//--- calculate accuracy in floating point (integer division would truncate to 0 or 1)
   if(total_results>0)
      model_accuracy=1.0*correct_results/total_results;
//---
   return(res);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//| Creates the model from the embedded ExtModel resource, runs the  |
//| sample-by-sample test and the batch test, then releases the      |
//| model. Always returns 0.                                         |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   string model_name="RidgeClassifier";
//--- create the model from the ONNX buffer embedded via #resource
   long model=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);
   if(model==INVALID_HANDLE)
     {
      PrintFormat("model_name=%s OnnxCreateFromBuffer error %d",model_name,GetLastError());
      return(0);
     }
//--- accuracy holder, reused by both tests (each test overwrites it)
   double model_accuracy=0;
//--- test sample by sample execution for all Iris dataset
   if(TestAllIrisDataset(model,model_name,model_accuracy))
      PrintFormat("model=%s all samples accuracy=%f",model_name,model_accuracy);
   else
      PrintFormat("error in testing model=%s ",model_name);
//--- test batch execution for several samples
   if(TestBatchExecution(model,model_name,model_accuracy))
      PrintFormat("model=%s batch test accuracy=%f",model_name,model_accuracy);
   else
      PrintFormat("error in testing model=%s ",model_name);
//--- release model
   OnnxRelease(model);
   return(0);
  }
//+------------------------------------------------------------------+