"""Convert a trained XGBoost snapshot model to ONNX plus JSON/INI runtime metadata."""
import argparse
import datetime as dt
import hashlib
import json
import os
import pickle
from typing import Dict

import numpy as np
import xgboost as xgb
from onnxmltools.convert import convert_xgboost
from onnxmltools.convert.common.data_types import FloatTensorType
|
def _read_json(path: str) -> Dict:
    """Load and return the JSON document stored at *path* (UTF-8)."""
    with open(path, "r", encoding="utf-8") as fh:
        parsed = json.load(fh)
    return parsed
def _load_scaler(path: str):
    """Unpickle a fitted scaler from *path* and sanity-check its attributes.

    Raises ValueError when the loaded object lacks the ``mean_``/``scale_``
    arrays a fitted StandardScaler-like object is expected to expose.

    NOTE(review): ``pickle.load`` executes arbitrary code — only load
    artifacts produced by the trusted training pipeline.
    """
    with open(path, "rb") as fh:
        obj = pickle.load(fh)
    missing = [attr for attr in ("mean_", "scale_") if not hasattr(obj, attr)]
    if missing:
        raise ValueError("Scaler pickle is missing mean_/scale_ attributes")
    return obj
def _write_json(dest: str, payload: Dict) -> None:
    """Serialize *payload* to *dest* as pretty-printed (indent=2) JSON."""
    serialized = json.dumps(payload, indent=2)
    with open(dest, "w", encoding="utf-8") as fh:
        fh.write(serialized)
def _format_list(values) -> str:
    """Render *values* as a '|'-separated list of %.12g-formatted floats."""
    parts = [format(float(item), ".12g") for item in values]
    return "|".join(parts)
def _write_ini(dest: str, meta: Dict, scaler, input_name: str, output_name: str) -> None:
    """Write the flat INI-style runtime configuration consumed by the EA.

    Args:
        dest: Destination file path.
        meta: Feature metadata dict (keys used: ``features``, ``label_mode``,
            ``positive_status``, ``categorical_mappings``).
        scaler: Fitted scaler exposing ``mean_``/``scale_`` arrays; identity
            scaling (zeros/ones) is emitted when the attributes are absent.
        input_name: ONNX graph input tensor name.
        output_name: ONNX graph output tensor name.
    """
    feature_names = meta.get("features", [])
    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC
    # timestamp instead — the rendered string format is unchanged.
    created = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    lines = [
        "# DualEA ONNX Runtime configuration",
        f"created={created}",
        "version=1",
        f"feature_count={len(feature_names)}",
        f"feature_names={'|'.join(feature_names)}",
        f"scaler_mean={_format_list(getattr(scaler, 'mean_', np.zeros(len(feature_names))))}",
        f"scaler_scale={_format_list(getattr(scaler, 'scale_', np.ones(len(feature_names))))}",
        f"input_name={input_name}",
        f"output_name={output_name}",
        f"label_mode={meta.get('label_mode', 'status')}",
        f"positive_status={'|'.join(meta.get('positive_status', []))}",
    ]
    # One "cat_<column>=value:code|..." line per non-empty categorical mapping.
    categorical = meta.get("categorical_mappings", {})
    for key, mapping in categorical.items():
        if isinstance(mapping, dict) and mapping:
            encoded = "|".join(f"{str(k)}:{int(v)}" for k, v in mapping.items())
            lines.append(f"cat_{key}={encoded}")
    with open(dest, "w", encoding="utf-8") as handle:
        handle.write("\n".join(lines))
def convert_to_onnx(model: xgb.XGBClassifier, feature_count: int, opset: int):
    """Convert a fitted XGBClassifier to serialized ONNX bytes.

    Returns a ``(model_bytes, input_name, output_name)`` tuple; the tensor
    names fall back to "input"/"output" when the graph declares none.
    """
    tensor_spec = [("input", FloatTensorType([None, feature_count]))]
    converted = convert_xgboost(
        model.get_booster(), initial_types=tensor_spec, target_opset=opset
    )

    graph = converted.graph
    in_name = graph.input[0].name if graph.input else "input"
    out_name = graph.output[0].name if graph.output else "output"
    return converted.SerializeToString(), in_name, out_name
def main():
    """CLI entry point: load training artifacts and export ONNX + metadata.

    Reads ``xgb_model.json``, ``feature_meta.json`` and ``scaler.pkl`` from
    the artifacts directory (or explicit path overrides), converts the model
    to ONNX, and writes the model file plus JSON and INI runtime configs.

    Raises:
        ValueError: if feature_meta.json lacks a non-empty 'features' list,
            or the scaler pickle is missing mean_/scale_ attributes.
    """
    parser = argparse.ArgumentParser(description="Convert trained XGBoost snapshot model to ONNX")
    parser.add_argument("--artifacts", type=str, default=os.path.join(os.getcwd(), "artifacts"))
    parser.add_argument("--model_json", type=str, default=None, help="Override path to xgb_model.json")
    parser.add_argument("--feature_meta", type=str, default=None, help="Override path to feature_meta.json")
    parser.add_argument("--scaler", type=str, default=None, help="Override path to scaler.pkl")
    parser.add_argument("--opset", type=int, default=15)
    parser.add_argument("--onnx_out", type=str, default=None, help="Destination ONNX path")
    parser.add_argument("--config_json", type=str, default=None, help="Destination JSON metadata path")
    parser.add_argument("--config_ini", type=str, default=None, help="Destination INI metadata path")
    args = parser.parse_args()

    # Resolve input/output locations, defaulting everything into --artifacts.
    artifacts = os.path.abspath(args.artifacts)
    model_path = args.model_json or os.path.join(artifacts, "xgb_model.json")
    feature_meta_path = args.feature_meta or os.path.join(artifacts, "feature_meta.json")
    scaler_path = args.scaler or os.path.join(artifacts, "scaler.pkl")
    onnx_path = args.onnx_out or os.path.join(artifacts, "signal_model.onnx")
    config_json_path = args.config_json or os.path.join(artifacts, "onnx_config.json")
    config_ini_path = args.config_ini or os.path.join(artifacts, "onnx_config.ini")

    os.makedirs(artifacts, exist_ok=True)

    # Validate feature metadata before doing any expensive work.
    meta = _read_json(feature_meta_path)
    feature_names = meta.get("features")
    if not feature_names:
        raise ValueError("feature_meta.json missing 'features'")
    feature_count = len(feature_names)

    scaler = _load_scaler(scaler_path)

    model = xgb.XGBClassifier()
    model.load_model(model_path)

    onnx_bytes, input_name, output_name = convert_to_onnx(model, feature_count, args.opset)
    with open(onnx_path, "wb") as handle:
        handle.write(onnx_bytes)

    payload = {
        "version": 1,
        # utcnow() is deprecated since Python 3.12; an aware UTC timestamp
        # renders to the same string with this format.
        "created": dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        "onnx_model": {
            "path": onnx_path,
            # Checksum lets consumers verify the model file they load.
            "sha256": hashlib.sha256(onnx_bytes).hexdigest(),
            "input_name": input_name,
            "output_name": output_name,
            "opset": args.opset,
        },
        "features": feature_names,
        "categorical_mappings": meta.get("categorical_mappings", {}),
        "label_mode": meta.get("label_mode", "status"),
        "positive_status": meta.get("positive_status", []),
        "scaler": {
            # Identity scaling fallback mirrors _write_ini's behavior.
            "mean": getattr(scaler, "mean_", np.zeros(feature_count)).tolist(),
            "scale": getattr(scaler, "scale_", np.ones(feature_count)).tolist(),
        },
    }

    _write_json(config_json_path, payload)
    _write_ini(config_ini_path, meta, scaler, input_name, output_name)

    print(f"[OK] ONNX model written to {onnx_path}")
    print(f"[OK] Config JSON written to {config_json_path}")
    print(f"[OK] Config INI written to {config_ini_path}")
# Script entry point: run the conversion pipeline when executed directly.
if __name__ == "__main__":
    main()