# IchiGridEA_ONNX_SourceBank/04_DEEP_REVIEW/extract_onnx_shapes_20260527.py
# (Python, 88 lines, 2.9 KiB)
#
# Extract the input/output tensor shapes of the upstream ONNX models into a CSV.

import csv
import hashlib
import os
from pathlib import Path

import onnx
# Root of the source bank. Defaults to the original workstation path; set the
# ICHIGRID_SOURCEBANK_REPO environment variable to run against a checkout
# located elsewhere (an empty value falls back to the default).
repo = Path(
    os.environ.get("ICHIGRID_SOURCEBANK_REPO")
    or r"F:\IchigridEA\_RESTORE_CONTROL_PLANE_V2\70_FORGE_MQL5_SOURCEBANK\repo_personnel_update\IchiGridEA_ONNX_SourceBank"
)
# Directory tree searched recursively for *.onnx models.
onnx_root = repo / "02_UPSTREAM_REFERENCES" / "renat_ONNX.Price.Prediction"
# Destination CSV for the extracted shape rows.
out_csv = repo / "04_DEEP_REVIEW" / "ONNX_MODEL_SHAPES_20260527.csv"
def sha256_file(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is read in 1 MiB chunks, so the whole file is never held in
    memory at once.
    """
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:  # b"" signals end of file
                break
            digest.update(chunk)
    return digest.hexdigest()
def dims_of(value_info):
    """Render the shape of an ONNX ``ValueInfoProto`` as a compact string.

    Tensor dimensions are joined inside brackets, e.g. ``"[1,60,1]"`` or
    ``"[batch,60,1]"``. Each dimension is shown as:

    * its integer size when ``dim_value`` is set (including an explicit 0),
    * its symbolic name when ``dim_param`` is set,
    * ``"?"`` when neither is set.

    Values whose type is not a tensor (for example sequence or map outputs)
    render as ``"<sequence_type>"`` / ``"<map_type>"`` / ``"<untyped>"``.
    Tensors that carry no shape at all render as ``"<unknown-rank>"``. In
    both cases the result is never confused with a scalar's ``"[]"``.
    """
    type_proto = value_info.type
    # TypeProto stores its variant in the ``value`` oneof; reading
    # ``.tensor_type`` on a non-tensor would silently return an empty default.
    type_kind = type_proto.WhichOneof("value")
    if type_kind != "tensor_type":
        return f"<{type_kind or 'untyped'}>"

    tensor_type = type_proto.tensor_type
    if not tensor_type.HasField("shape"):
        return "<unknown-rank>"

    dims = []
    for d in tensor_type.shape.dim:
        # Use the oneof rather than truthiness so a dimension explicitly set
        # to 0 is reported as "0" instead of "?".
        which = d.WhichOneof("value")
        if which == "dim_value":
            dims.append(str(d.dim_value))
        elif which == "dim_param":
            dims.append(d.dim_param)
        else:
            dims.append("?")
    return "[" + ",".join(dims) + "]"
rows = []
for model_path in sorted(onnx_root.rglob("*.onnx")):
    # A load failure leaves ``model`` as None. A checker failure keeps the
    # loaded proto, so its input/output shapes are still reported (tagged
    # ONNX_CHECK_FAIL) instead of being discarded.
    model = None
    try:
        model = onnx.load(str(model_path))
        onnx.checker.check_model(model)
    except Exception as exc:  # deliberate best-effort survey: record and continue
        status = "ONNX_CHECK_FAIL"
        # Prefix the exception type so load errors and checker errors are distinguishable.
        error = f"{type(exc).__name__}: {exc}"
    else:
        status = "ONNX_CHECK_PASS"
        error = ""

    # Columns shared by every row emitted for this model. The file is hashed
    # once here rather than once per input/output row.
    common = {
        "ModelPath": model_path.relative_to(repo).as_posix(),
        "SHA256": sha256_file(model_path),
        "Status": status,
        "Error": error,
    }

    if model is None:
        rows.append({
            **common,
            "Kind": "MODEL",
            "Name": model_path.name,
            "Shape": "",
            "Opset": "",
            "IRVersion": "",
        })
        continue

    opset = ",".join(str(o.version) for o in model.opset_import)
    # ONNX IR version < 4 required initializers (weights) to also be listed in
    # graph.input, and some exporters still do this. They are not values a
    # caller feeds, so leave them out of the INPUT rows.
    initializer_names = {init.name for init in model.graph.initializer}
    feed_inputs = [v for v in model.graph.input if v.name not in initializer_names]

    for kind, value_infos in (("INPUT", feed_inputs), ("OUTPUT", model.graph.output)):
        for value_info in value_infos:
            rows.append({
                **common,
                "Kind": kind,
                "Name": value_info.name,
                "Shape": dims_of(value_info),
                "Opset": opset,
                "IRVersion": model.ir_version,
            })
# "utf-8-sig" prefixes the file with a UTF-8 BOM (presumably so spreadsheet
# tools detect the encoding). newline="" lets the csv module control line
# endings, as the csv documentation requires.
with out_csv.open("w", encoding="utf-8-sig", newline="") as handle:
    csv_writer = csv.DictWriter(
        handle,
        fieldnames=[
            "ModelPath",
            "Kind",
            "Name",
            "Shape",
            "Opset",
            "IRVersion",
            "SHA256",
            "Status",
            "Error",
        ],
    )
    csv_writer.writeheader()
    for row in rows:
        csv_writer.writerow(row)

print("ONNX shape rows: %d" % len(rows))
print(out_csv)