80 lines
3 KiB
Python
80 lines
3 KiB
Python
|
# Iris_ExtraTreesClassifier.py
|
||
|
# The code demonstrates the process of training an ExtraTreesClassifier model on the Iris dataset, exporting it to ONNX format, and making predictions with the ONNX model.
|
||
|
# It also evaluates the accuracy of both the original model and the ONNX model.
|
||
|
# Copyright 2023, MetaQuotes Ltd.
|
||
|
# https://www.mql5.com
|
||
|
|
||
|
# import necessary libraries
|
||
|
import os
from sys import argv

import numpy as np
import onnxruntime as ort
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn import datasets
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import accuracy_score, classification_report
|
||
|
|
||
|
# define the path for saving the model
# Resolve the directory that contains this script. os.path understands the
# platform's separators, whereas the former rfind("\\") trick only worked on
# Windows and produced "" on POSIX. dirname() runs before abspath() so that an
# empty or bare-filename argv[0] resolves to the current working directory.
# The trailing separator added by os.path.join(..., "") keeps data_path usable
# as a plain string prefix, exactly like the original value.
data_path = os.path.join(os.path.abspath(os.path.dirname(argv[0])), "")
# Fetch the bundled Iris dataset and split it into the feature matrix X and
# the target labels y.
iris = datasets.load_iris()
X, y = iris.data, iris.target
# Build an ExtraTreesClassifier with default hyperparameters and fit it on the
# entire dataset. scikit-learn's fit() returns the estimator itself, so
# construction and training are chained into a single statement.
extra_trees_model = ExtraTreesClassifier().fit(X, y)
# In-sample evaluation: the model is scored on the same samples it was fitted
# on, so this measures training fit rather than generalization.
y_pred = extra_trees_model.predict(X)

# Overall accuracy of the scikit-learn model.
accuracy = accuracy_score(y, y_pred)
print("Accuracy of ExtraTreesClassifier model:", accuracy)

# Per-class precision/recall/F1 summary.
report = classification_report(y, y_pred)
print("\nClassification Report:\n", report)
# define the input data type
# One float input named "float_input" with a dynamic batch dimension (None)
# and one column per feature of X.
initial_type = [('float_input', FloatTensorType([None, X.shape[1]]))]

# Convert the fitted scikit-learn model to an ONNX graph (opset 12).
onnx_model = convert_sklearn(extra_trees_model, initial_types=initial_type, target_opset=12)

# Build the output path with os.path so it is valid on every platform; the
# previous "..\\models\\..." literal embedded Windows separators, which become
# part of the file name on POSIX. The target is <data_path>/../models/.
onnx_filename = os.path.normpath(
    os.path.join(data_path, "..", "models", "extra_trees_iris.onnx")
)

# Make sure the "models" directory exists; otherwise open() below raises
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(onnx_filename), exist_ok=True)

# Serialize the ONNX protobuf to disk (binary mode).
with open(onnx_filename, "wb") as f:
    f.write(onnx_model.SerializeToString())

# print model path
print(f"Model saved to {onnx_filename}")
def _print_tensor_info(title, tensors):
    """Print *title*, then one numbered line per tensor with its name, type and shape.

    *tensors* is what ``InferenceSession.get_inputs()`` / ``get_outputs()``
    returns; every item is read through its ``name``, ``type`` and ``shape``
    attributes. Numbering starts at 1.
    """
    print(title)
    for number, tensor in enumerate(tensors, start=1):
        print(f"{number}. Name: {tensor.name}, Data Type: {tensor.type}, Shape: {tensor.shape}")


# load the ONNX model and make predictions
# The CPU execution provider is requested explicitly: since onnxruntime 1.9,
# omitting `providers` raises ValueError on builds that ship more than one
# execution provider (e.g. onnxruntime-gpu). CPU is available in every build.
onnx_session = ort.InferenceSession(onnx_filename, providers=["CPUExecutionProvider"])

# Names used to feed the input and to fetch the first output, presumably the
# predicted labels, since they are scored with accuracy_score afterwards.
input_name = onnx_session.get_inputs()[0].name
output_name = onnx_session.get_outputs()[0].name

# display information about input and output tensors in ONNX
_print_tensor_info("\nInformation about input tensors in ONNX:", onnx_session.get_inputs())
_print_tensor_info("\nInformation about output tensors in ONNX:", onnx_session.get_outputs())
# The model was exported with a FloatTensorType input, so the features are
# cast to float32 before being handed to ONNX Runtime.
X_float32 = X.astype(np.float32)

# Run inference on the whole dataset. Only one output name is requested, so
# run() returns a one-element list, which is unpacked directly.
feed = {input_name: X_float32}
(y_pred_onnx,) = onnx_session.run([output_name], feed)

# Score the ONNX predictions against the true labels.
accuracy_onnx = accuracy_score(y, y_pred_onnx)
print("\nAccuracy of ExtraTreesClassifier model in ONNX format:", accuracy_onnx)