markaz_arshy/train_xgboost.py
2025-08-12 14:36:24 +00:00

143 lines
No EOL
5 KiB
Python

# train_xgboost.py
#
# Deskripsi:
# Skrip ini melatih model XGBoost untuk setiap simbol berdasarkan data latih
# yang dihasilkan oleh generate_training_data.py.
import json
import logging
import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, precision_score, recall_score, f1_score, roc_auc_score
# --- Configuration ---
# Root logger setup: INFO level, lines formatted as "time - LEVEL - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# JSON file holding the list of training records read by main().
TRAINING_DATA_PATH = 'trade_feedback.json'
# Per-symbol model file name; "{symbol}" is filled in with str.format().
MODEL_OUTPUT_PATH = 'xgboost_model_{symbol}.json'
def load_training_data(filepath):
    """Load the training records from a JSON file.

    Args:
        filepath: Path to a UTF-8 JSON file whose top-level value is expected
            to be a list of trade records.

    Returns:
        The decoded list. Returns an empty list, and logs an error, in these cases:
        the file cannot be opened or read, it is not valid UTF-8 JSON, or its
        top-level value is not a list. Errors are logged rather than raised,
        so the caller can treat every failure as "no data".
    """
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would mis-decode or reject non-ASCII content.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
        # OSError covers FileNotFoundError plus permission/directory errors.
        logging.error(f"Error saat memuat data latih: {e}")
        return []
    if not isinstance(data, list):
        # Callers iterate over the result as records; a dict or scalar here
        # would only fail later with a confusing error.
        logging.error(
            f"Error saat memuat data latih: diharapkan list, didapat {type(data).__name__}"
        )
        return []
    return data
# Fallback hyperparameters used when no readable tuning result exists.
_DEFAULT_XGB_PARAMS = {
    'max_depth': 6,
    'learning_rate': 0.1,
    'n_estimators': 100,
    'subsample': 0.8,
    'colsample_bytree': 0.8,
}
# Minimum examples of each class required so that both stratified splits
# (test, then validation) keep at least one example of every class.
_MIN_SAMPLES_PER_CLASS = 5


def _prepare_dataset(symbol, trades):
    """Build the feature matrix ``X`` and the binary label vector ``y``.

    Returns ``(X, y)``, or ``None`` after logging an error. ``None`` is returned
    when a record lacks a required key, its result is not a string, or the
    feature vectors are not equal-length, non-empty numeric sequences.
    """
    try:
        # dtype=float makes ragged or non-numeric feature lists fail here
        # (ValueError/TypeError). Otherwise numpy would build an object array
        # that XGBoost rejects later.
        X = np.array([trade['gng_input_features_on_signal'] for trade in trades], dtype=float)
        y = np.array([1 if trade['result'].upper() == "WIN" else 0 for trade in trades])
    except (KeyError, AttributeError, TypeError, ValueError) as e:
        logging.error(f"Format data tidak valid untuk {symbol}: {e}")
        return None
    if X.ndim != 2 or X.shape[1] == 0:
        logging.error(
            f"Format data tidak valid untuk {symbol}: fitur harus berupa list angka "
            f"dengan panjang sama (shape {X.shape})"
        )
        return None
    return X, y


def _load_hyperparameters(symbol):
    """Return tuned params from ``tuning_results_{symbol}.json``, else the defaults."""
    try:
        with open(f'tuning_results_{symbol}.json', 'r', encoding='utf-8') as f:
            params = json.load(f)
    except FileNotFoundError:
        logging.warning(f"File tuning tidak ditemukan untuk {symbol}, menggunakan parameter default")
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
        logging.warning(f"File tuning untuk {symbol} tidak dapat dibaca ({e}), menggunakan parameter default")
    else:
        if isinstance(params, dict):
            logging.info(f"Menggunakan parameter hasil tuning untuk {symbol}")
            return params
        logging.warning(f"File tuning untuk {symbol} bukan objek JSON, menggunakan parameter default")
    # Return a copy so callers cannot mutate the shared defaults.
    return dict(_DEFAULT_XGB_PARAMS)


def _compute_metrics(y_true, y_pred, y_proba):
    """Return the evaluation metrics as plain Python floats (JSON-serialisable)."""
    return {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        # zero_division=0 reports 0 instead of emitting a warning when the
        # model predicts no positive samples.
        'precision': float(precision_score(y_true, y_pred, zero_division=0)),
        'recall': float(recall_score(y_true, y_pred, zero_division=0)),
        'f1': float(f1_score(y_true, y_pred, zero_division=0)),
        'roc_auc': float(roc_auc_score(y_true, y_proba)),
    }


def train_model_for_symbol(symbol, trades):
    """Train, evaluate and save an XGBoost classifier for one symbol.

    Each trade record must provide two keys:

    - ``'gng_input_features_on_signal'``: a fixed-length numeric sequence used
      as the feature vector.
    - ``'result'``: a string; ``"WIN"`` (case-insensitive) is the positive class
      and anything else is negative.

    The data is split with stratification into train / validation / test sets.
    The validation set drives early stopping. The test set is used only for
    the final metrics, so those metrics are not biased by the choice of
    stopping round.

    Hyperparameters come from ``tuning_results_{symbol}.json`` when that file
    exists and holds a JSON object; otherwise built-in defaults are used.

    Side effects: writes the model to ``MODEL_OUTPUT_PATH.format(symbol=symbol)``
    and the metrics to ``model_metrics_{symbol}.json``. Both are relative paths,
    resolved against the working directory.

    Args:
        symbol: Symbol name; also used to build the input/output file names.
        trades: List of trade record dicts for this symbol.

    Returns:
        ``(model, metrics)`` on success. ``None`` when training is skipped:
        no data, malformed records, fewer than 100 samples, or too few samples
        of either class to stratify.
    """
    if not trades:
        logging.warning(f"Tidak ada data untuk simbol {symbol}. Melewati pelatihan.")
        return None

    dataset = _prepare_dataset(symbol, trades)
    if dataset is None:
        return None
    X, y = dataset

    if len(X) < 100:  # Minimum number of samples required
        logging.warning(f"Data untuk {symbol} terlalu sedikit ({len(X)}). Minimal 100 sampel diperlukan.")
        return None

    class_ratio = np.mean(y)
    if class_ratio < 0.2 or class_ratio > 0.8:
        logging.warning(f"Ketidakseimbangan kelas terdeteksi untuk {symbol}. Win rate: {class_ratio:.2%}")

    # Stratified splitting, early stopping and ROC AUC all need both classes.
    n_pos = int(y.sum())
    minority = min(n_pos, len(y) - n_pos)
    if minority < _MIN_SAMPLES_PER_CLASS:
        logging.warning(
            f"Kelas minoritas untuk {symbol} hanya memiliki {minority} sampel "
            f"(minimal {_MIN_SAMPLES_PER_CLASS}). Melewati pelatihan."
        )
        return None

    # Held-out test set: used ONLY for the final evaluation.
    X_train_full, X_test, y_train_full, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    # A separate validation set for early stopping. The stopping round is then
    # not chosen on the same data the reported metrics are computed on.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_full, y_train_full, test_size=0.2, random_state=42, stratify=y_train_full
    )

    # Since XGBoost 1.6, eval_metric and early_stopping_rounds are constructor
    # params; passing them to fit() was removed in 2.0. Tuned values are
    # merged last, so a key also present in the tuning file overrides the
    # fixed value instead of raising a duplicate-keyword TypeError.
    model_params = {
        'objective': 'binary:logistic',
        'random_state': 42,
        'eval_metric': ['logloss', 'error'],
        'early_stopping_rounds': 20,
    }
    model_params.update(_load_hyperparameters(symbol))
    model = xgb.XGBClassifier(**model_params)
    model.fit(
        X_train, y_train,
        eval_set=[(X_val, y_val)],
        verbose=True,
    )

    # Full evaluation on the untouched test set
    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    metrics = _compute_metrics(y_test, y_pred, y_pred_proba)

    logging.info(f"\n--- Evaluasi Model {symbol} ---")
    for metric_name, value in metrics.items():
        logging.info(f"{metric_name}: {value:.4f}")
    logging.info("\nClassification Report:")
    logging.info(f"\n{classification_report(y_test, y_pred, zero_division=0)}")

    # Save the model and its evaluation metrics
    model_path = MODEL_OUTPUT_PATH.format(symbol=symbol)
    model.save_model(model_path)
    metrics_path = f'model_metrics_{symbol}.json'
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=4)
    logging.info(f"Model dan metrik untuk {symbol} disimpan di {model_path} dan {metrics_path}")
    return model, metrics
def main():
    """Train one XGBoost model per symbol found in ``TRAINING_DATA_PATH``.

    Records are grouped by their ``'symbol'`` field. Records that are not
    dicts, or that have no symbol, are skipped with a warning instead of
    aborting the run. An exception while training one symbol is logged with
    its traceback and does not stop the remaining symbols.
    """
    training_data = load_training_data(TRAINING_DATA_PATH)
    if not training_data:
        logging.error("Tidak ada data latih yang valid. Keluar.")
        return

    # Group records by symbol, skipping records that cannot be attributed.
    trades_by_symbol = {}
    skipped = 0
    for trade in training_data:
        symbol = trade.get('symbol') if isinstance(trade, dict) else None
        if symbol is None:
            skipped += 1
            continue
        trades_by_symbol.setdefault(symbol, []).append(trade)
    if skipped:
        logging.warning(f"{skipped} record dilewati karena tidak memiliki field 'symbol'.")

    # Train each symbol independently; a failure is logged, then we move on.
    trained = 0
    for symbol, trades in trades_by_symbol.items():
        logging.info(f"Melatih model untuk simbol {symbol} ({len(trades)} trade)...")
        try:
            result = train_model_for_symbol(symbol, trades)
        except Exception:
            logging.exception(f"Pelatihan gagal untuk simbol {symbol}.")
            continue
        if result is not None:
            trained += 1
    logging.info(f"Pelatihan model selesai: {trained}/{len(trades_by_symbol)} simbol berhasil dilatih.")
# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()