intelligent-trading-bot/scripts/train.py

147 lines
5 KiB
Python

2022-03-20 10:09:33 +01:00
from pathlib import Path
from datetime import datetime, timezone, timedelta
import click
2022-07-24 11:06:33 +02:00
from tqdm import tqdm
2022-03-20 10:09:33 +01:00
import numpy as np
import pandas as pd
from service.App import *
from common.model_store import *
from common.generators import train_feature_set
2022-03-20 10:09:33 +01:00
"""
2022-07-16 11:41:27 +02:00
Train models for all target labels and all algorithms declared in the configuration using the specified features.
2022-03-20 10:09:33 +01:00
"""
@click.command()
@click.option('--config_file', '-c', type=click.Path(), default='', help='Configuration file name')
def main(config_file):
    """
    Train models for all target labels and all algorithms declared in the configuration
    using the specified features, and store them via the model store.

    Processing steps:
    1. Load the feature matrix ``<data_folder>/<symbol>/<matrix_file_name>`` (parquet or csv).
    2. Keep the time/OHLCV columns present in the file plus the configured
       ``train_features`` and ``labels``; boolean labels are converted to integers.
    3. Drop the last ``label_horizon`` rows (their labels are computed from future values
       that are not available) and keep at most ``train_length`` rows.
    4. Train every entry of ``train_feature_sets`` with ``train_feature_set`` and store all
       resulting model pairs with ``App.model_store.put_model_pair``.

    Problems (no feature sets, missing matrix file name, missing file, unsupported extension,
    missing columns, no data) are reported with an ``ERROR:`` message and the function
    returns without training.

    :param config_file: configuration file name passed to ``load_config``
    """
    load_config(config_file)
    config = App.config

    App.model_store = ModelStore(config)
    App.model_store.load_models()

    # Fail fast before loading (possibly large) data if there is nothing to train
    train_feature_sets = config.get("train_feature_sets", [])
    if not train_feature_sets:
        print("ERROR: no train feature sets defined. Nothing to process.")
        return

    time_column = config["time_column"]

    now = datetime.now()

    symbol = config["symbol"]
    data_path = Path(config["data_folder"]) / symbol

    # Determine desired data length depending on train/predict mode
    is_train = config.get("train")
    if is_train:
        window_size = config.get("train_length")
    else:
        # This script always trains models. Without train mode the input is limited by the
        # prediction window, which may not be what the user intended.
        window_size = config.get("predict_length")
        print("WARNING: Train mode is not specified although this script is intended for training. "
              "Input data is limited by 'predict_length'.")
    # Extend the window by 'features_horizon' rows (presumably history needed to compute
    # features); a missing value counts as 0 instead of failing with a TypeError.
    features_horizon = config.get("features_horizon") or 0
    if window_size:
        window_size += features_horizon

    #
    # Load feature matrix
    #
    matrix_file_name = config.get("matrix_file_name")
    if not matrix_file_name:
        print("ERROR: Parameter 'matrix_file_name' is not specified in the configuration.")
        return
    file_path = data_path / matrix_file_name
    if not file_path.is_file():
        print(f"ERROR: Input file does not exist: {file_path}")
        return

    print(f"Loading data from source data file {file_path}...")
    if file_path.suffix == ".parquet":
        df = pd.read_parquet(file_path)
    elif file_path.suffix == ".csv":
        df = pd.read_csv(file_path, parse_dates=[time_column], date_format="ISO8601")
    else:
        print(f"ERROR: Unknown extension of the input file '{file_path.suffix}'. Only 'csv' and 'parquet' are supported")
        return

    print(f"Finished loading {len(df)} records with {len(df.columns)} columns from the source file {file_path}")

    # Select only the data necessary for analysis
    if window_size:
        df = df.tail(window_size)
    df = df.reset_index(drop=True)

    # Guard: iloc[0]/iloc[-1] below would raise IndexError on an empty frame
    if df.empty:
        print(f"ERROR: No input data records available in {file_path}. Nothing to process.")
        return

    print(f"Input data size {len(df)} records. Range: [{df.iloc[0][time_column]}, {df.iloc[-1][time_column]}]")

    #
    # Prepare data by selecting columns and rows
    #
    # Default (common) values for all trained features
    train_features_all = config.get("train_features") or []
    labels_all = config["labels"]

    # Select necessary features and labels: existing time/price columns first, then features and labels
    out_columns = [time_column, 'open', 'high', 'low', 'close', 'volume', 'close_time']
    out_columns = [x for x in out_columns if x in df.columns]
    all_features = train_features_all + labels_all
    missing_columns = [x for x in all_features if x not in df.columns]
    if missing_columns:
        print(f"ERROR: Feature or label columns not found in the input data: {missing_columns}")
        return
    # dict.fromkeys drops duplicate names (e.g. a label also listed as a feature) while keeping order;
    # duplicated columns would make df[label] a DataFrame instead of a Series.
    # .copy() makes the selection an independent frame so the label assignment below is safe.
    df = df[list(dict.fromkeys(out_columns + all_features))].copy()

    for label in labels_all:
        # is_bool_dtype also accepts pandas extension dtypes, on which np.issubdtype raises
        if pd.api.types.is_bool_dtype(df[label].dtype):
            df[label] = df[label].astype(int)  # For classification tasks we want to use integers

    label_horizon = config["label_horizon"]  # Labels are generated from future data and hence we might want to explicitly remove some tail rows
    train_length = config.get("train_length")
    # Remove the tail data for which no (correct) labels are available
    # The reason is that these labels are computed from future values which are not available and hence labels might be wrong
    if label_horizon:
        df = df.head(-label_horizon)
    # Limit maximum length for all algorithms (algorithms can further limit their train size)
    if train_length:
        df = df.tail(train_length)

    # Handle NULLs: infinities are treated as missing values.
    # The result is assigned rather than using inplace=True because df may be a head()/tail()
    # slice here, and in-place modification of a slice can trigger SettingWithCopyWarning.
    df = df.replace([np.inf, -np.inf], np.nan)
    na_count = int(df[train_features_all].isna().any(axis=1).sum())
    if na_count > 0:
        print(f"WARNING: There exist {na_count} rows with NULLs in some feature columns")

    df = df.reset_index(drop=True)  # To remove gaps in index before use

    if df.empty:
        print("ERROR: No records left for training after removing the label horizon. Nothing to process.")
        return

    #
    # Train feature models
    #
    print(f"Start training models for {len(df)} input records.")

    models = dict()
    fs_count = len(train_feature_sets)
    # Numbering starts at 1 so that progress reads 1/N ... N/N
    for i, fs in enumerate(train_feature_sets, start=1):
        fs_now = datetime.now()
        print(f"Start train feature set {i}/{fs_count}. Generator {fs.get('generator')}...")

        fs_models = train_feature_set(df, fs, config)
        models.update(fs_models)

        fs_elapsed = datetime.now() - fs_now
        print(f"Finished train feature set {i}/{fs_count}. Generator {fs.get('generator')}. Time: {str(fs_elapsed).split('.')[0]}")

    print("Finished training models.")

    #
    # Store all collected models in files
    #
    for score_column_name, model_pair in models.items():
        App.model_store.put_model_pair(score_column_name, model_pair)

    print(f"Models stored in path: {App.model_store.model_path.absolute()}")

    #
    # End
    #
    elapsed = datetime.now() - now
    print(f"Finished training models in {str(elapsed).split('.')[0]}")


if __name__ == '__main__':
    main()