88 lines
2.9 KiB
Python
88 lines
2.9 KiB
Python
|
|
import csv
|
||
|
|
import hashlib
|
||
|
|
from pathlib import Path
|
||
|
|
import onnx
|
||
|
|
|
||
|
|
# Root of the local source-bank checkout that this script inventories.
repo = Path(r"F:\IchigridEA\_RESTORE_CONTROL_PLANE_V2\70_FORGE_MQL5_SOURCEBANK\repo_personnel_update\IchiGridEA_ONNX_SourceBank")
# Upstream reference tree that is searched recursively for *.onnx files.
onnx_root = repo.joinpath("02_UPSTREAM_REFERENCES", "renat_ONNX.Price.Prediction")
# Destination of the generated model-shape CSV.
out_csv = repo.joinpath("04_DEEP_REVIEW", "ONNX_MODEL_SHAPES_20260527.csv")
|
||
|
|
|
||
|
|
def sha256_file(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is opened in binary mode and consumed in 1 MiB chunks, so large
    model files are hashed without being loaded into memory all at once.
    """
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:  # b"" signals end of file
                break
            digest.update(chunk)
    return digest.hexdigest()
|
||
|
|
|
||
|
|
def dims_of(value_info):
    """Format the declared shape of an ONNX graph input/output as text.

    Args:
        value_info: an ONNX ``ValueInfoProto`` (an entry of ``graph.input`` or
            ``graph.output``).

    Returns:
        ``"[d0,d1,...]"`` for (sparse) tensors, where each entry is the fixed
        size (``dim_value``, including an explicit 0), the symbolic name
        (``dim_param``), or ``"?"`` when the dimension carries neither.
        ``"?"`` when the value has no type or the tensor declares no shape at
        all (unknown rank), so it is not confused with a scalar (``"[]"``).
        ``"<kind>"`` (e.g. ``"<sequence_type>"``) for non-tensor values, which
        have no single tensor shape.
    """
    value_type = value_info.type
    kind = value_type.WhichOneof("value")
    if kind is None:
        return "?"
    if kind == "tensor_type":
        tensor_type = value_type.tensor_type
    elif kind == "sparse_tensor_type":
        tensor_type = value_type.sparse_tensor_type
    else:
        return f"<{kind}>"

    # An absent shape field means the rank itself is unknown.
    if not tensor_type.HasField("shape"):
        return "?"

    dims = []
    for d in tensor_type.shape.dim:
        # Use the oneof discriminator rather than truthiness so that an
        # explicitly declared zero-sized dimension is reported as "0".
        which = d.WhichOneof("value")
        if which == "dim_value":
            dims.append(str(d.dim_value))
        elif which == "dim_param" and d.dim_param:
            dims.append(d.dim_param)
        else:
            dims.append("?")
    return "[" + ",".join(dims) + "]"
|
||
|
|
|
||
|
|
# Collected CSV rows: one per graph input/output, or a single MODEL row for a
# model whose inputs/outputs could not be listed.
rows = []

# Path.rglob on a missing directory yields nothing, which would silently
# produce a header-only CSV; fail with a clear message instead.
if not onnx_root.is_dir():
    raise SystemExit(f"ONNX source directory not found: {onnx_root}")

for model_path in sorted(onnx_root.rglob("*.onnx")):
    try:
        model = onnx.load(str(model_path))
        onnx.checker.check_model(model)
        status = "ONNX_CHECK_PASS"
        error = ""
    except Exception as exc:
        # Deliberately broad: a load or checker failure is recorded in the
        # CSV and the scan continues with the next file.
        model = None
        status = "ONNX_CHECK_FAIL"
        error = str(exc)

    # Columns shared by every row of this model. The file is hashed once here
    # instead of being re-read and re-hashed for every emitted row.
    common = {
        "ModelPath": model_path.relative_to(repo).as_posix(),
        "SHA256": sha256_file(model_path),
        "Status": status,
        "Error": error,
    }

    model_rows = []
    opset = ""
    ir_version = ""
    if model is not None:
        opset = ",".join(str(o.version) for o in model.opset_import)
        ir_version = model.ir_version
        # NOTE(review): in IR version < 4 models graph.input is expected to
        # also list initializers (weights) — confirm whether those rows are
        # wanted in the inventory.
        for kind, values in (("INPUT", model.graph.input),
                             ("OUTPUT", model.graph.output)):
            for value_info in values:
                model_rows.append({
                    **common,
                    "Kind": kind,
                    "Name": value_info.name,
                    "Shape": dims_of(value_info),
                    "Opset": opset,
                    "IRVersion": ir_version,
                })

    if not model_rows:
        # Failed models (and loaded models declaring no inputs/outputs) still
        # get one row so every .onnx file appears in the inventory.
        model_rows.append({
            **common,
            "Kind": "MODEL",
            "Name": model_path.name,
            "Shape": "",
            "Opset": opset,
            "IRVersion": ir_version,
        })

    rows.extend(model_rows)

# The review directory may not exist yet on a fresh checkout.
out_csv.parent.mkdir(parents=True, exist_ok=True)

# utf-8-sig writes a BOM so Excel detects UTF-8; newline="" is required by csv.
with out_csv.open("w", encoding="utf-8-sig", newline="") as f:
    fieldnames = ["ModelPath", "Kind", "Name", "Shape", "Opset", "IRVersion", "SHA256", "Status", "Error"]
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)

print(f"ONNX shape rows: {len(rows)}")
print(out_csv)
|