ONNX.Price.Prediction/Python/PricePrediction.py

47 lines
1.2 KiB
Python
Raw Permalink Normal View History

2024-08-08 23:31:54 +02:00
# Copyright 2023, MetaQuotes Ltd.
# https://www.mql5.com
from pathlib import Path
from sys import argv

import MetaTrader5 as mt5
import numpy as np
import onnxruntime as ort
import pandas as pd
# Inputs of the prediction, collected in one place. They presumably have to
# match what model.onnx was trained on -- TODO confirm before changing them.
SYMBOL = "EURUSD"
BAR_COUNT = 10  # number of bars fed to the model
START_POS = 1  # bar 0 is the current, still-forming bar, so start at 1
FEATURES = ("open", "high", "low", "close")  # tuple: safe as a default arg
CLOSE_INDEX = FEATURES.index("close")  # column whose stats de-normalize output
MODEL_FILE = "model.onnx"
INPUT_NAME = "lstm_input"  # name of the model's input tensor


def model_path_for(script_path, file_name=MODEL_FILE):
    """Return the path of *file_name* in the directory containing *script_path*.

    On Windows, pathlib accepts both backslash and forward-slash separators.
    The previous ``rfind("\\")`` approach handled only backslashes. A bare
    file name resolves relative to the current directory.
    """
    return Path(script_path).parent / file_name


def fetch_window(symbol=SYMBOL, timeframe=None, start_pos=START_POS,
                 count=BAR_COUNT, features=FEATURES):
    """Download up to *count* bars and return them as a (1, bars, len(features)) array.

    *timeframe* defaults to ``mt5.TIMEFRAME_H1``. The default is resolved
    lazily so that importing this module does not touch MetaTrader5.

    Raises:
        RuntimeError: if the terminal returns no bars.
    """
    if timeframe is None:
        timeframe = mt5.TIMEFRAME_H1
    rates = mt5.copy_rates_from_pos(symbol, timeframe, start_pos, count)
    # copy_rates_from_pos returns None on error. Without this check the
    # failure surfaced later as a KeyError on an empty DataFrame.
    if rates is None or len(rates) == 0:
        raise RuntimeError(
            f"copy_rates_from_pos({symbol!r}) failed, "
            f"error code = {mt5.last_error()}"
        )
    frame = pd.DataFrame(rates)
    # list(): indexing a DataFrame with a tuple would mean one MultiIndex key.
    values = frame[list(features)].to_numpy()
    # Add a leading batch axis of size 1.
    return np.expand_dims(values, axis=0)


def normalize_window(window):
    """Z-score *window* per feature along the bar axis (axis 1).

    Returns ``(normalized, mean, std)``. *mean* and *std* keep the reduced
    axis, so they broadcast against *window* and can undo the scaling.

    A feature whose values are all equal has zero spread. Its std is
    replaced by 1, as scikit-learn's StandardScaler does, so the feature
    normalizes to 0 instead of NaN.
    """
    mean = window.mean(axis=1, keepdims=True)
    std = window.std(axis=1, keepdims=True)
    std = np.where(std == 0, 1.0, std)
    return (window - mean) / std, mean, std


def denormalize_close(y_norm, mean, std, close_index=CLOSE_INDEX, decimals=5):
    """Map normalized model output back to price units.

    Uses the mean and std of feature column *close_index*. Returns a 1-D
    array rounded to *decimals* places.
    """
    scale = std[:, 0, close_index]
    offset = mean[:, 0, close_index]
    return np.round(np.ravel(y_norm) * scale + offset, decimals=decimals)


def main():
    """Connect to the terminal, run the ONNX model on recent bars, print the forecast."""
    if not mt5.initialize():
        print("initialize() failed, error code =", mt5.last_error())
        # Exit with a non-zero status. quit() exited with 0, which hid the
        # failure, and quit() is unavailable when the site module is skipped.
        raise SystemExit(1)
    try:
        model_path = model_path_for(argv[0])
        print("data path to load onnx model", model_path.parent)
        window = fetch_window()
        window_norm, mean, std = normalize_window(window)
        # Explicit provider: GPU builds of onnxruntime (>= 1.9) refuse to build
        # a session without one, and a single small window does not need a GPU.
        session = ort.InferenceSession(
            str(model_path), providers=["CPUExecutionProvider"]
        )
        # output_names=None requests every model output, in declaration order.
        outputs = session.run(None, {INPUT_NAME: window_norm.astype(np.float32)})
        y_pred_norm = outputs[0]
        print("raw output:", y_pred_norm)
        print("predicted:", denormalize_close(y_pred_norm, mean, std))
    finally:
        # Release the terminal connection even if fetching or inference fails.
        mt5.shutdown()


if __name__ == "__main__":
    main()