mql5/gold/database_manager.py

441 Zeilen
16 KiB
Python

import sqlite3
import pandas as pd
import json
from datetime import datetime
import logging
import os
logger = logging.getLogger("DatabaseManager")
class DatabaseManager:
def __init__(self, db_path="trading_data.db"):
# 如果传入的是默认文件名(非绝对路径),将其转换为基于当前文件所在目录的绝对路径
if not os.path.isabs(db_path):
current_dir = os.path.dirname(os.path.abspath(__file__))
self.db_path = os.path.join(current_dir, os.path.basename(db_path))
else:
self.db_path = db_path
# 确保路径存在
logger.info(f"Database path: {self.db_path}")
self.conn = None
self._init_db()
def _get_connection(self):
"""Helper to get a database connection with proper timeout and retry"""
import time
max_retries = 5
# If we already have a connection, check if it's valid
if self.conn:
try:
self.conn.execute("SELECT 1")
return self.conn
except Exception:
# Connection might be broken, try to reconnect
try:
self.conn.close()
except: pass
self.conn = None
for i in range(max_retries):
try:
# check_same_thread=False allows Streamlit to use connection across threads
self.conn = sqlite3.connect(self.db_path, timeout=30.0, check_same_thread=False)
# Enable WAL mode for better concurrency immediately upon connection
self.conn.execute('PRAGMA journal_mode=WAL;')
return self.conn
except sqlite3.OperationalError as e:
if "unable to open database file" in str(e) or "database is locked" in str(e):
if i < max_retries - 1:
time.sleep(1.0) # Increase wait time
continue
logger.error(f"Connection failed after {i+1} attempts: {e}")
raise e
raise sqlite3.OperationalError("Failed to connect to database after retries")
def _init_db(self):
"""Initialize the database tables"""
try:
conn = self._get_connection()
cursor = conn.cursor()
# Table for market data (OHLCV)
cursor.execute('''
CREATE TABLE IF NOT EXISTS market_data (
timestamp DATETIME,
symbol TEXT,
timeframe TEXT,
open REAL,
high REAL,
low REAL,
close REAL,
volume INTEGER,
PRIMARY KEY (timestamp, symbol, timeframe)
)
''')
# Table for signals and analysis
cursor.execute('''
CREATE TABLE IF NOT EXISTS signals (
timestamp DATETIME,
symbol TEXT,
timeframe TEXT,
signal TEXT,
strength REAL,
source TEXT, -- 'CRT', 'PriceEq', 'DeepSeek', 'Qwen', 'Hybrid'
details TEXT -- JSON string with extra details
)
''')
# Table for trades
cursor.execute('''
CREATE TABLE IF NOT EXISTS trades (
ticket INTEGER PRIMARY KEY,
symbol TEXT,
action TEXT, -- 'BUY', 'SELL'
volume REAL,
price REAL,
time DATETIME,
result TEXT, -- 'OPEN', 'CLOSED', 'ERROR'
close_price REAL,
close_time DATETIME,
profit REAL,
mfe REAL, -- Maximum Favorable Excursion
mae REAL -- Maximum Adverse Excursion
)
''')
# Attempt to add columns if they don't exist (for existing DB)
try:
cursor.execute("ALTER TABLE trades ADD COLUMN close_price REAL")
except: pass
try:
cursor.execute("ALTER TABLE trades ADD COLUMN close_time DATETIME")
except: pass
try:
cursor.execute("ALTER TABLE trades ADD COLUMN profit REAL")
except: pass
try:
cursor.execute("ALTER TABLE trades ADD COLUMN mfe REAL")
except: pass
try:
cursor.execute("ALTER TABLE trades ADD COLUMN mae REAL")
except: pass
# Table for account metrics (Balance, Equity, etc.)
cursor.execute('''
CREATE TABLE IF NOT EXISTS account_metrics (
timestamp DATETIME,
balance REAL,
equity REAL,
margin REAL,
free_margin REAL,
margin_level REAL,
total_profit REAL, -- Floating PnL
symbol_pnl REAL, -- PnL for specific symbol (optional)
PRIMARY KEY (timestamp)
)
''')
conn.commit()
# conn.close() # Persistent connection, do not close
except Exception as e:
logger.error(f"Database initialization failed: {e}")
def save_account_metrics(self, metrics):
"""Save account metrics"""
try:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute('''
INSERT INTO account_metrics (timestamp, balance, equity, margin, free_margin, margin_level, total_profit, symbol_pnl)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
metrics.get('timestamp', datetime.now()),
metrics.get('balance', 0),
metrics.get('equity', 0),
metrics.get('margin', 0),
metrics.get('free_margin', 0),
metrics.get('margin_level', 0),
metrics.get('total_profit', 0),
metrics.get('symbol_pnl', 0)
))
conn.commit()
except Exception as e:
logger.error(f"Failed to save account metrics: {e}")
def get_latest_account_metrics(self):
"""Get the most recent account metrics"""
try:
conn = self._get_connection()
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
cursor.execute('SELECT * FROM account_metrics ORDER BY timestamp DESC LIMIT 1')
row = cursor.fetchone()
if row:
return dict(row)
return None
except Exception as e:
logger.error(f"Failed to get account metrics: {e}")
return None
def get_historical_account_metrics(self, hours_ago=24):
"""Get account metrics from N hours ago"""
try:
conn = self._get_connection()
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
# SQLite datetime modifier
cursor.execute(f"SELECT * FROM account_metrics WHERE timestamp <= datetime('now', '-{hours_ago} hours') ORDER BY timestamp DESC LIMIT 1")
row = cursor.fetchone()
if row:
return dict(row)
return None
except Exception as e:
logger.error(f"Failed to get historical metrics: {e}")
return None
def get_start_of_day_metrics(self):
"""Get account metrics from the start of the current day"""
try:
conn = self._get_connection()
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
cursor.execute("SELECT * FROM account_metrics WHERE date(timestamp) = date('now') ORDER BY timestamp ASC LIMIT 1")
row = cursor.fetchone()
if row:
return dict(row)
# If no record today, try last record of yesterday
cursor.execute("SELECT * FROM account_metrics WHERE date(timestamp) < date('now') ORDER BY timestamp DESC LIMIT 1")
row = cursor.fetchone()
if row:
return dict(row)
return None
except Exception as e:
logger.error(f"Failed to get start of day metrics: {e}")
return None
def save_market_data(self, df, symbol, timeframe):
try:
conn = self._get_connection()
# Add symbol and timeframe columns if they don't exist in df
data_to_save = df.copy()
data_to_save['symbol'] = symbol
data_to_save['timeframe'] = timeframe
data_to_save['timestamp'] = data_to_save.index
cursor = conn.cursor()
for index, row in data_to_save.iterrows():
cursor.execute('''
REPLACE INTO market_data (timestamp, symbol, timeframe, open, high, low, close, volume)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (row['timestamp'].to_pydatetime(), symbol, timeframe, row['open'], row['high'], row['low'], row['close'], row['volume']))
conn.commit()
# conn.close()
except Exception as e:
logger.error(f"Failed to save market data: {e}")
def save_signal(self, symbol, timeframe, signal_data):
"""Save analysis signal"""
try:
conn = self._get_connection()
cursor = conn.cursor()
timestamp = datetime.now()
# Save Hybrid result
cursor.execute('''
INSERT INTO signals (timestamp, symbol, timeframe, signal, strength, source, details)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (timestamp, symbol, timeframe, signal_data['final_signal'], signal_data['strength'], 'Hybrid', json.dumps(signal_data['details'])))
conn.commit()
# conn.close()
except Exception as e:
logger.error(f"Failed to save signal: {e}")
def save_trade(self, trade_data):
"""Save trade execution"""
try:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO trades (ticket, symbol, action, volume, price, time, result)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
trade_data['ticket'],
trade_data['symbol'],
trade_data['action'],
trade_data['volume'],
trade_data['price'],
datetime.now(),
trade_data.get('result', 'OPEN')
))
conn.commit()
# conn.close()
except Exception as e:
logger.error(f"Failed to save trade: {e}")
def update_trade_performance(self, ticket, close_data):
"""Update trade with close info and performance metrics"""
try:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute('''
UPDATE trades
SET result = 'CLOSED',
close_price = ?,
close_time = ?,
profit = ?,
mfe = ?,
mae = ?
WHERE ticket = ?
''', (
close_data.get('close_price'),
close_data.get('close_time'),
close_data.get('profit'),
close_data.get('mfe'),
close_data.get('mae'),
ticket
))
conn.commit()
# conn.close()
except Exception as e:
logger.error(f"Failed to update trade performance: {e}")
def get_trade_performance_stats(self, limit=50):
"""Get aggregated performance stats (MFE, MAE) for recent closed trades"""
try:
conn = self._get_connection()
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
cursor.execute('''
SELECT
AVG(mfe) as avg_mfe,
AVG(mae) as avg_mae,
AVG(profit) as avg_profit,
COUNT(*) as count
FROM trades
WHERE result = 'CLOSED' AND mfe IS NOT NULL
ORDER BY close_time DESC
LIMIT ?
''', (limit,))
row = cursor.fetchone()
stats = None
if row and row['count'] > 0:
stats = {
"avg_mfe": row['avg_mfe'],
"avg_mae": row['avg_mae'],
"avg_profit": row['avg_profit'],
"trade_count": row['count']
}
# 获取最近的 MFE/MAE 数据列表用于分布分析
cursor.execute('''
SELECT ticket, mfe, mae, profit, action
FROM trades
WHERE result = 'CLOSED' AND mfe IS NOT NULL
ORDER BY close_time DESC
LIMIT ?
''', (limit,))
rows = cursor.fetchall()
details = []
for r in rows:
details.append({
"ticket": r['ticket'],
"mfe": r['mfe'],
"mae": r['mae'],
"profit": r['profit'],
"action": r['action']
})
if stats:
stats["recent_trades"] = details
return stats
except Exception as e:
logger.error(f"Failed to get trade stats: {e}")
return None
def get_open_trades(self):
"""Get all trades that are currently marked as OPEN in the database"""
try:
conn = self._get_connection()
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
cursor.execute('''
SELECT * FROM trades
WHERE result = 'OPEN' OR result IS NULL
''')
rows = cursor.fetchall()
# conn.close()
# Convert rows to list of dicts
trades = []
for row in rows:
trades.append(dict(row))
return trades
except Exception as e:
logger.error(f"Failed to get open trades: {e}")
return []
def get_latest_signals(self, limit=50):
"""Get recent signals for dashboard"""
try:
conn = self._get_connection()
query = "SELECT * FROM signals ORDER BY timestamp DESC LIMIT ?"
df = pd.read_sql_query(query, conn, params=(limit,))
# conn.close()
return df
except Exception as e:
logger.error(f"Failed to get signals: {e}")
return pd.DataFrame()
def get_market_data(self, symbol, limit=100):
"""Get market data for dashboard"""
try:
conn = self._get_connection()
query = "SELECT * FROM market_data WHERE symbol = ? ORDER BY timestamp DESC LIMIT ?"
df = pd.read_sql_query(query, conn, params=(symbol, limit))
# conn.close()
if not df.empty:
df['timestamp'] = pd.to_datetime(df['timestamp'])
df = df.sort_values('timestamp')
return df
except Exception as e:
logger.error(f"Failed to get market data: {e}")
return pd.DataFrame()
def get_trades(self, limit=50):
"""Get trade history from local DB"""
try:
conn = self._get_connection()
query = "SELECT * FROM trades ORDER BY time DESC LIMIT ?"
df = pd.read_sql_query(query, conn, params=(limit,))
# conn.close()
return df
except Exception as e:
logger.error(f"Failed to get trades: {e}")
return pd.DataFrame()