AiDataGenByLeo/Py/out_fix.py
Nique_372 4a046263f7
2026-03-19 15:59:52 -05:00

72 lines
2.1 KiB
Python

# Copyright 2026, Niquel Mendoza | Leo.
# https://www.mql5.com/es/users/nique_372
# trainer_regression.py
import onnx
from onnx import helper
# Helper function
def fix_onnx_output_shape(input_path : str, output_path : str) -> None:
"""
Modifica un modelo ONNX de regresión para cambiar output de [1] a [1,1]
"""
print(f"Cargando modelo: {input_path}")
model = onnx.load(input_path)
# Obtener output actual
old_output = model.graph.output[0]
print(f"Output original: {old_output.name}")
print(f"Shape original: {[d.dim_value for d in old_output.type.tensor_type.shape.dim]}")
# Crear nodo Reshape
reshape_node = helper.make_node(
'Reshape',
inputs=[old_output.name, 'reshape_shape'],
outputs=['output_reshaped'],
name='fix_output_shape'
)
# Crear tensor con shape [1, 1]
shape_tensor = helper.make_tensor(
name='reshape_shape',
data_type=onnx.TensorProto.INT64,
dims=[2],
vals=[1, 1]
)
# Agregar al grafo
model.graph.node.append(reshape_node)
model.graph.initializer.append(shape_tensor)
# IMPORTANTE: Agregar también como input del grafo (requerido por ONNX)
shape_input = helper.make_tensor_value_info(
'reshape_shape',
onnx.TensorProto.INT64,
[2]
)
model.graph.input.append(shape_input)
# Crear nuevo output [1, 1]
new_output = helper.make_tensor_value_info(
'output_reshaped',
onnx.TensorProto.FLOAT,
[1, 1]
)
# Reemplazar output
model.graph.output.remove(old_output)
model.graph.output.append(new_output)
# Guardar
onnx.save(model, output_path)
print(f"Modelo guardado: {output_path}")
# Verificar (sin checker estricto)
model_check = onnx.load(output_path)
try:
onnx.checker.check_model(model_check)
except Exception as e:
print(f" Exepcion al chekear modelo {e}")
new_shape = [d.dim_value for d in model_check.graph.output[0].type.tensor_type.shape.dim]
print(f" Shape final: {new_shape}")