forked from nique_372/AiDataGenByLeo
72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
# Copyright 2026, Niquel Mendoza | Leo.
|
|
# https://www.mql5.com/es/users/nique_372
|
|
# trainer_regression.py
|
|
|
|
import onnx
|
|
from onnx import helper
|
|
|
|
# Funcion helper
|
|
def fix_onnx_output_shape(input_path: str, output_path: str) -> None:
    """Reshape the first output of an ONNX regression model from [1] to [1, 1].

    Loads the model at ``input_path`` and appends a ``Reshape`` node whose
    target shape is the constant ``[1, 1]``, fed by the model's first output.
    It then declares the reshaped tensor as the graph's first output, in place
    of the original one. The result is saved to ``output_path``, reloaded, run
    through the ONNX checker (best-effort), and its final output shape printed.

    Args:
        input_path: Path of the ONNX model to read.
        output_path: Path where the modified model is written.

    Raises:
        ValueError: If the model graph declares no outputs.
    """
    print(f"Cargando modelo: {input_path}")
    model = onnx.load(input_path)
    graph = model.graph

    if not graph.output:
        raise ValueError(f"Model {input_path!r} has no graph outputs")

    def _shape_of(value_info):
        # Symbolic dims (e.g. a named batch axis) have dim_value == 0;
        # show their symbolic name instead so the printout is not misleading.
        return [d.dim_param or d.dim_value for d in value_info.type.tensor_type.shape.dim]

    # Collect every name already used in the graph so the tensors/node we add
    # cannot clash with existing ones (e.g. when the fix is applied twice).
    taken = set()
    for node in graph.node:
        taken.update(node.input)
        taken.update(node.output)
        taken.add(node.name)
    taken.update(t.name for t in graph.initializer)
    taken.update(v.name for v in graph.input)
    taken.update(v.name for v in graph.output)
    taken.update(v.name for v in graph.value_info)

    def _fresh(base):
        # Return `base` if unused, otherwise `base_1`, `base_2`, ...
        name, suffix = base, 1
        while name in taken:
            name = f"{base}_{suffix}"
            suffix += 1
        taken.add(name)
        return name

    # Current (first) output: capture what we need before it is overwritten below.
    old_output = graph.output[0]
    old_name = old_output.name
    # Keep the original element type (FLOAT, DOUBLE, ...) rather than assuming
    # FLOAT; fall back to FLOAT only if the output leaves it undefined (0).
    elem_type = old_output.type.tensor_type.elem_type or onnx.TensorProto.FLOAT
    print(f"Output original: {old_name}")
    print(f"Shape original: {_shape_of(old_output)}")

    shape_name = _fresh("reshape_shape")
    reshaped_name = _fresh("output_reshaped")

    # Constant target shape [1, 1] consumed by the Reshape node.
    shape_tensor = helper.make_tensor(
        name=shape_name,
        data_type=onnx.TensorProto.INT64,
        dims=[2],
        vals=[1, 1],
    )
    graph.initializer.append(shape_tensor)

    # Before IR version 4 every initializer also had to be declared as a graph
    # input. From IR 4 on it must not be: listing it there turns the constant
    # into an overridable extra model input (ONNX Runtime warns and cannot
    # constant-fold it, and consumers expecting a single input break).
    if model.ir_version < 4:
        graph.input.append(
            helper.make_tensor_value_info(shape_name, onnx.TensorProto.INT64, [2])
        )

    reshape_node = helper.make_node(
        "Reshape",
        inputs=[old_name, shape_name],
        outputs=[reshaped_name],
        name=_fresh("fix_output_shape"),
    )
    # Appending keeps the node list topologically sorted: the new node only
    # consumes a tensor that is already produced earlier in the graph.
    graph.node.append(reshape_node)

    # Replace output 0 in place (instead of remove + append) so any other
    # outputs keep their positions and output[0] is still the reshaped one.
    new_output = helper.make_tensor_value_info(reshaped_name, elem_type, [1, 1])
    graph.output[0].CopyFrom(new_output)

    onnx.save(model, output_path)
    print(f"Modelo guardado: {output_path}")

    # Reload the written file and validate it with the ONNX checker. This is
    # best-effort: a validation failure is reported but does not abort, since
    # the model has already been saved.
    model_check = onnx.load(output_path)
    try:
        onnx.checker.check_model(model_check)
    except Exception as e:
        print(f" Exepcion al chekear modelo {e}")

    new_shape = _shape_of(model_check.graph.output[0])
    print(f" Shape final: {new_shape}")
|