266 lines
21 KiB
MQL5
266 lines
21 KiB
MQL5
//+------------------------------------------------------------------+
|
|
//| Iris_KNearestNeighborsClassifier.mq5 |
|
|
//| Copyright 2023, MetaQuotes Ltd. |
|
|
//| https://www.mql5.com |
|
|
//+------------------------------------------------------------------+
|
|
#property copyright "Copyright 2023, MetaQuotes Ltd."
|
|
#property link "https://www.mql5.com"
|
|
#property version "1.00"
|
|
|
|
#include "iris.mqh"
|
|
#resource "models\\knn_iris.onnx" as const uchar ExtModel[];
|
|
|
|
//+------------------------------------------------------------------+
//| Run the ONNX model on a batch of IRIS samples                    |
//| model            - handle returned by OnnxCreateFromBuffer       |
//| input_data       - batch of samples, 4 float features per row    |
//| model_classes_id - receives, per sample, the key of the entry    |
//|                    with the highest probability in output 1      |
//| Returns false on empty input, if the model run fails, or if the  |
//| number of output maps does not match the batch size.             |
//+------------------------------------------------------------------+
bool TestSamples(long model,float &input_data[][4], int &model_classes_id[])
  {
//--- check number of input samples
   ulong batch_size=input_data.Range(0);
   if(batch_size==0)
      return(false);
//--- prepare output array
   ArrayResize(model_classes_id,(int)batch_size);
//--- output 0 buffer (not used by the post-processing below)
   float output_data[];
//--- output 1 buffer: one key/value map per sample
//--- (assumed from output_shape2={batch_size}; TODO confirm against the model)
   struct Map
     {
      ulong          key[];
      float          value[];
     } output_data_map[];
//--- check consistency
   bool res=ArrayResize(output_data,(int)batch_size)==batch_size;
   if(!res)
      return(false);
//--- set input shape
   ulong input_shape[]= {batch_size,input_data.Range(1)};
   OnnxSetInputShape(model,0,input_shape);
//--- set output shapes
   ulong output_shape1[]= {batch_size};
   ulong output_shape2[]= {batch_size};
   OnnxSetOutputShape(model,0,output_shape1);
   OnnxSetOutputShape(model,1,output_shape2);
//--- run the model
   res=OnnxRun(model,0,input_data,output_data,output_data_map);
   if(!res)
      return(false);
//--- exactly one map per sample is required: more would write past
//--- model_classes_id, fewer would leave trailing results unset
   uint maps_total=output_data_map.Size();
   if(maps_total!=(uint)batch_size)
     {
      PrintFormat("unexpected number of output maps %u, expected %u",maps_total,(uint)batch_size);
      return(false);
     }
//--- postprocessing of sequence map data:
//--- for each sample pick the key with maximum probability
   for(uint n=0; n<maps_total; n++)
     {
      int   model_class_id=-1;
      float max_value=0;
      //--- read the map in place instead of copying it into reused dynamic
      //--- arrays (ArrayCopy never shrinks them, which could leave stale data)
      int values_total=ArraySize(output_data_map[n].value);
      int keys_total  =ArraySize(output_data_map[n].key);
      int pairs_total =(values_total<keys_total ? values_total : keys_total);
      for(int k=0; k<pairs_total; k++)
        {
         float prob=output_data_map[n].value[k];
         if(k==0 || prob>max_value)
           {
            max_value=prob;
            model_class_id=(int)output_data_map[n].key[k];
           }
        }
      //--- store the result (-1 if the map was empty)
      model_classes_id[n]=model_class_id;
     }
//---
   return(true);
  }
|
|
|
|
//+------------------------------------------------------------------+
//| Test all samples from IRIS dataset (150)                         |
//| Here we test all samples with batch=1, sample by sample          |
//| model_accuracy receives correct / total samples (0..1).          |
//| Returns false if the dataset could not be prepared.              |
//+------------------------------------------------------------------+
bool TestAllIrisDataset(const long model,const string model_name,double &model_accuracy)
  {
   sIRISsample iris_samples[];
//--- fill the dataset (PrepareIrisDataset presumably comes from iris.mqh)
   PrepareIrisDataset(iris_samples);
//--- test
   int total_samples=ArraySize(iris_samples);
   if(total_samples==0)
     {
      Print("iris dataset not prepared");
      return(false);
     }
//--- array for output classes
   int model_output_classes_id[];
//--- check all Iris dataset samples
   int correct_results=0;
   for(int k=0; k<total_samples; k++)
     {
      //--- input array: a batch containing only the kth sample
      float iris_sample_input_data[1][4];
      for(int f=0; f<4; f++)
         iris_sample_input_data[0][f]=(float)iris_samples[k].features[f];
      //--- run model; an inference failure counts as a wrong answer but is reported
      if(!TestSamples(model,iris_sample_input_data,model_output_classes_id))
        {
         PrintFormat("model:%s sample=%d inference FAILED",model_name,iris_samples[k].sample_id);
         continue;
        }
      //--- check result
      if(model_output_classes_id[0]==iris_samples[k].class_id)
         correct_results++;
      else
         PrintFormat("model:%s sample=%d FAILED [class=%d, true class=%d] features=(%.2f,%.2f,%.2f,%.2f)",model_name,iris_samples[k].sample_id,model_output_classes_id[0],iris_samples[k].class_id,iris_samples[k].features[0],iris_samples[k].features[1],iris_samples[k].features[2],iris_samples[k].features[3]);
     }
   model_accuracy=1.0*correct_results/total_samples;
//---
   PrintFormat("model:%s correct results: %.2f%%",model_name,100*model_accuracy);
//---
   return(true);
  }
|
|
|
|
//+------------------------------------------------------------------+
//| Run one batch through the model and compare every prediction     |
//| with the expected class. Increments correct_results and          |
//| total_results. Returns false if the model run failed.            |
//+------------------------------------------------------------------+
bool CheckBatchResults(const long model,const string model_name,float &input_data[][4],const int &correct_classes[],int &correct_results,int &total_results)
  {
   int model_output_classes_id[];
   if(!TestSamples(model,input_data,model_output_classes_id))
      return(false);
//--- never index past either array
   int outputs_total =ArraySize(model_output_classes_id);
   int expected_total=ArraySize(correct_classes);
   int checks_total  =(outputs_total<expected_total ? outputs_total : expected_total);
   for(int j=0; j<checks_total; j++)
     {
      if(model_output_classes_id[j]==correct_classes[j])
         correct_results++;
      else
         PrintFormat("model:%s FAILED [class=%d, true class=%d] features=(%.2f,%.2f,%.2f,%.2f)",model_name,model_output_classes_id[j],correct_classes[j],input_data[j][0],input_data[j][1],input_data[j][2],input_data[j][3]);
      total_results++;
     }
   return(true);
  }
//+------------------------------------------------------------------+
//| Here we test batch execution of the model                        |
//| model_accuracy receives correct / total checked samples (0..1).  |
//| Returns false if any batch run fails or nothing was checked.     |
//+------------------------------------------------------------------+
bool TestBatchExecution(const long model,const string model_name,double &model_accuracy)
  {
   model_accuracy=0;
   int correct_results=0;
   int total_results=0;
//--- run batch with 3 samples
   float input_data_batch3[3][4]=
     {
        {5.1f,3.5f,1.4f,0.2f}, // iris dataset sample id=1, Iris-setosa
        {6.3f,2.5f,4.9f,1.5f}, // iris dataset sample id=73, Iris-versicolor
        {6.3f,2.7f,4.9f,1.8f}  // iris dataset sample id=124, Iris-virginica
     };
   int correct_classes_batch3[3]= {0,1,2};
   if(!CheckBatchResults(model,model_name,input_data_batch3,correct_classes_batch3,correct_results,total_results))
      return(false);
//--- run batch with 10 samples
   float input_data_batch10[10][4]=
     {
        {5.5f,3.5f,1.3f,0.2f}, // iris dataset sample id=37 (Iris-setosa)
        {4.9f,3.1f,1.5f,0.1f}, // iris dataset sample id=38 (Iris-setosa)
        {4.4f,3.0f,1.3f,0.2f}, // iris dataset sample id=39 (Iris-setosa)
        {5.0f,3.3f,1.4f,0.2f}, // iris dataset sample id=50 (Iris-setosa)
        {7.0f,3.2f,4.7f,1.4f}, // iris dataset sample id=51 (Iris-versicolor)
        {6.4f,3.2f,4.5f,1.5f}, // iris dataset sample id=52 (Iris-versicolor)
        {6.3f,3.3f,6.0f,2.5f}, // iris dataset sample id=101 (Iris-virginica)
        {5.8f,2.7f,5.1f,1.9f}, // iris dataset sample id=102 (Iris-virginica)
        {7.1f,3.0f,5.9f,2.1f}, // iris dataset sample id=103 (Iris-virginica)
        {6.3f,2.9f,5.6f,1.8f}  // iris dataset sample id=104 (Iris-virginica)
     };
//--- correct classes for all 10 samples in the batch
   int correct_classes_batch10[10]= {0,0,0,0,1,1,2,2,2,2};
   if(!CheckBatchResults(model,model_name,input_data_batch10,correct_classes_batch10,correct_results,total_results))
      return(false);
//--- calculate accuracy; guard against a zero divide if nothing was checked
   if(total_results==0)
     {
      PrintFormat("model:%s batch test produced no results",model_name);
      return(false);
     }
//--- floating-point division: integer division would truncate to 0 or 1
   model_accuracy=1.0*correct_results/total_results;
//---
   return(true);
  }
|
|
//+------------------------------------------------------------------+
//| Script program start function                                    |
//| Creates the model from the embedded ExtModel resource, runs the  |
//| per-sample and batch tests, then releases the model.             |
//| Returns 0 on success, -1 if the model could not be created.      |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   string model_name="KNearestNeighborsClassifier";
//--- create the model from the ONNX buffer embedded via #resource
   long model=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);
   if(model==INVALID_HANDLE)
     {
      PrintFormat("model_name=%s OnnxCreateFromBuffer error %d",model_name,GetLastError());
      return(-1);
     }
//--- test all dataset
   double model_accuracy=0;
//--- test sample by sample execution for all Iris dataset
   if(TestAllIrisDataset(model,model_name,model_accuracy))
      PrintFormat("model=%s all samples accuracy=%f",model_name,model_accuracy);
   else
      PrintFormat("error in testing model=%s ",model_name);
//--- test batch execution for several samples
   if(TestBatchExecution(model,model_name,model_accuracy))
      PrintFormat("model=%s batch test accuracy=%f",model_name,model_accuracy);
   else
      PrintFormat("error in testing model=%s ",model_name);
//--- release model
   OnnxRelease(model);
   return(0);
  }
|
|
//+------------------------------------------------------------------+
|