//+------------------------------------------------------------------+
//|                                                   Iris_NuSVC.mq5 |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
|
|
#property copyright "Copyright 2023, MetaQuotes Ltd."
|
|
#property link "https://www.mql5.com"
|
|
#property version "1.00"
|
|
|
|
#include "iris.mqh"
|
|
#resource "models\\nusvc_iris.onnx" as const uchar ExtModel[];
|
|
|
|
//+------------------------------------------------------------------+
//| Test IRIS dataset samples                                        |
//| Runs the ONNX model on a batch of samples (4 features per row)   |
//| and copies the predicted class ids (model output 0) into         |
//| model_classes_id, one entry per input row.                       |
//| Returns false on an empty batch or on any ONNX call failure.     |
//+------------------------------------------------------------------+
bool TestSamples(long model,float &input_data[][4], int &model_classes_id[])
  {
//--- check number of input samples
   ulong batch_size=input_data.Range(0);
   if(batch_size==0)
      return(false);
//--- prepare output array
   ArrayResize(model_classes_id,(int)batch_size);
//--- input 0 has shape [batch_size, features]
   ulong input_shape[]= { batch_size, input_data.Range(1)};
   if(!OnnxSetInputShape(model,0,input_shape))
     {
      PrintFormat("OnnxSetInputShape error %d",GetLastError());
      return(false);
     }
//--- output 0: one class id per sample;
//--- output 1: shape [batch_size,3] (presumably per-class scores - not used here)
   int   output1[];
   float output2[][3];
   ArrayResize(output1,(int)batch_size);
   ArrayResize(output2,(int)batch_size);
//---
   ulong output_shape[]= {batch_size};
   if(!OnnxSetOutputShape(model,0,output_shape))
     {
      PrintFormat("OnnxSetOutputShape(0) error %d",GetLastError());
      return(false);
     }
//---
   ulong output_shape2[]= {batch_size,3};
   if(!OnnxSetOutputShape(model,1,output_shape2))
     {
      PrintFormat("OnnxSetOutputShape(1) error %d",GetLastError());
      return(false);
     }
//--- run inference; the model writes both outputs in one call
   if(!OnnxRun(model,ONNX_DEBUG_LOGS,input_data,output1,output2))
     {
      PrintFormat("OnnxRun error %d",GetLastError());
      return(false);
     }
//--- classes are ready in output1[k]
   for(int k=0; k<(int)batch_size; k++)
      model_classes_id[k]=output1[k];
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| Test all samples from IRIS dataset (150)                         |
//| Here we test all samples with batch=1, sample by sample          |
//| model_accuracy receives correct/total (0 if the dataset is empty |
//| or could not be prepared). Samples whose inference fails are     |
//| reported and counted as incorrect.                               |
//+------------------------------------------------------------------+
bool TestAllIrisDataset(const long model,const string model_name,double &model_accuracy)
  {
   model_accuracy=0;
   sIRISsample iris_samples[];
//--- load dataset from file
   PrepareIrisDataset(iris_samples);
//--- test
   int total_samples=ArraySize(iris_samples);
   if(total_samples==0)
     {
      Print("iris dataset not prepared");
      return(false);
     }
//--- array for output classes
   int model_output_classes_id[];
//--- single-sample input buffer, fully overwritten on every iteration
   float iris_sample_input_data[1][4];
//--- check all Iris dataset samples
   int correct_results=0;
   for(int k=0; k<total_samples; k++)
     {
      //--- prepare input data from kth iris sample dataset
      for(int f=0; f<4; f++)
         iris_sample_input_data[0][f]=(float)iris_samples[k].features[f];
      //--- run model; a failed run is reported instead of being silently skipped
      if(!TestSamples(model,iris_sample_input_data,model_output_classes_id))
        {
         PrintFormat("model:%s sample=%d inference FAILED",model_name,iris_samples[k].sample_id);
         continue;
        }
      //--- check result
      if(model_output_classes_id[0]==iris_samples[k].class_id)
         correct_results++;
      else
         PrintFormat("model:%s sample=%d FAILED [class=%d, true class=%d] features=(%.2f,%.2f,%.2f,%.2f)",model_name,iris_samples[k].sample_id,model_output_classes_id[0],iris_samples[k].class_id,iris_samples[k].features[0],iris_samples[k].features[1],iris_samples[k].features[2],iris_samples[k].features[3]);
     }
   model_accuracy=1.0*correct_results/total_samples;
//---
   PrintFormat("model:%s correct results: %.2f%%",model_name,100*model_accuracy);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| Compare the predicted classes of one batch with the expected     |
//| ones, print every mismatch and accumulate the counters.          |
//+------------------------------------------------------------------+
void CountBatchMatches(const string model_name,float &input_data[][4],
                       const int &expected_classes[],const int &predicted_classes[],
                       int &correct_results,int &total_results)
  {
   int count=ArraySize(predicted_classes);
   for(int j=0; j<count; j++)
     {
      if(predicted_classes[j]==expected_classes[j])
         correct_results++;
      else
         PrintFormat("model:%s FAILED [class=%d, true class=%d] features=(%.2f,%.2f,%.2f,%.2f)",model_name,predicted_classes[j],expected_classes[j],input_data[j][0],input_data[j][1],input_data[j][2],input_data[j][3]);
      total_results++;
     }
  }
//+------------------------------------------------------------------+
//| Here we test batch execution of the model                        |
//| Runs a 3-sample and a 10-sample batch; model_accuracy receives   |
//| the fraction of correct predictions over both batches.           |
//| Returns false if either batch fails to run.                      |
//+------------------------------------------------------------------+
bool TestBatchExecution(const long model,const string model_name,double &model_accuracy)
  {
   model_accuracy=0;
//--- array for output classes
   int model_output_classes_id[];
   int correct_results=0;
   int total_results=0;

//--- run batch with 3 samples
   float input_data_batch3[3][4]=
     {
        {5.1f,3.5f,1.4f,0.2f}, // iris dataset sample id=1, Iris-setosa
        {6.3f,2.5f,4.9f,1.5f}, // iris dataset sample id=73, Iris-versicolor
        {6.3f,2.7f,4.9f,1.8f}  // iris dataset sample id=124, Iris-virginica
     };
   int correct_classes_batch3[3]= {0,1,2};
   if(!TestSamples(model,input_data_batch3,model_output_classes_id))
      return(false);
   CountBatchMatches(model_name,input_data_batch3,correct_classes_batch3,
                     model_output_classes_id,correct_results,total_results);

//--- run batch with 10 samples
   float input_data_batch10[10][4]=
     {
        {5.5f,3.5f,1.3f,0.2f}, // iris dataset sample id=37 (Iris-setosa)
        {4.9f,3.1f,1.5f,0.1f}, // iris dataset sample id=38 (Iris-setosa)
        {4.4f,3.0f,1.3f,0.2f}, // iris dataset sample id=39 (Iris-setosa)
        {5.0f,3.3f,1.4f,0.2f}, // iris dataset sample id=50 (Iris-setosa)
        {7.0f,3.2f,4.7f,1.4f}, // iris dataset sample id=51 (Iris-versicolor)
        {6.4f,3.2f,4.5f,1.5f}, // iris dataset sample id=52 (Iris-versicolor)
        {6.3f,3.3f,6.0f,2.5f}, // iris dataset sample id=101 (Iris-virginica)
        {5.8f,2.7f,5.1f,1.9f}, // iris dataset sample id=102 (Iris-virginica)
        {7.1f,3.0f,5.9f,2.1f}, // iris dataset sample id=103 (Iris-virginica)
        {6.3f,2.9f,5.6f,1.8f}  // iris dataset sample id=104 (Iris-virginica)
     };
//--- correct classes for all 10 samples in the batch
   int correct_classes_batch10[10]= {0,0,0,0,1,1,2,2,2,2};
   if(!TestSamples(model,input_data_batch10,model_output_classes_id))
      return(false);
   CountBatchMatches(model_name,input_data_batch10,correct_classes_batch10,
                     model_output_classes_id,correct_results,total_results);

//--- calculate accuracy in floating point: plain int/int would truncate to 0 or 1
   if(total_results>0)
      model_accuracy=1.0*correct_results/total_results;
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//| Creates the NuSVC ONNX session from the embedded resource, runs  |
//| the per-sample and batch tests, then releases the session.       |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   const string model_name="NuSVC";
//--- create the ONNX session from the buffer embedded via #resource
   long model=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);
   if(model==INVALID_HANDLE)
     {
      PrintFormat("model_name=%s OnnxCreateFromBuffer error %d",model_name,GetLastError());
      return(0);
     }
//--- test sample by sample execution for all Iris dataset
   double model_accuracy=0;
   if(TestAllIrisDataset(model,model_name,model_accuracy))
      PrintFormat("model=%s all samples accuracy=%f",model_name,model_accuracy);
   else
      PrintFormat("error in testing model=%s ",model_name);
//--- test batch execution for several samples
   model_accuracy=0;
   if(TestBatchExecution(model,model_name,model_accuracy))
      PrintFormat("model=%s batch test accuracy=%f",model_name,model_accuracy);
   else
      PrintFormat("error in testing model=%s ",model_name);
//--- release model
   OnnxRelease(model);
   return(0);
  }
//+------------------------------------------------------------------+