- Overview
- Features
- Installation
- Data Requirements
- Parameters
- Usage
- Output
- Models
- Logging
- License
- Acknowledgments
This Python script performs stock signal prediction and payoff analysis using machine learning classifiers. It processes financial data to generate buy signals for securities, evaluates model performance, and visualizes the payoff over a specified period. It uses NumPy and pandas for data manipulation, scikit-learn, CatBoost, and XGBoost for scaling and model training, and Plotly for visualization.
- Data Preprocessing: Reads and preprocesses sector and financial data.
- Feature Engineering: Calculates the Sharpe ratio as an additional feature.
- Scaling: Normalizes features using Min-Max scaling.
- Model Training: Utilizes CatBoost and XGBoost classifiers with cross-validation.
- Signal Generation: Predicts buy signals based on model probabilities.
- Performance Evaluation: Calculates accuracy of predictions and tracks payoff.
- Visualization: Generates interactive plots for signals and payoff using Plotly.
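The feature-engineering and scaling steps above can be sketched as follows. This is a minimal illustration, not the script's actual code: the README does not say how the Sharpe ratio is defined, so the sketch assumes a rolling mean of returns divided by their rolling standard deviation per security, with no risk-free rate. The function names, the `window` argument, and the demo data are hypothetical.

```python
import pandas as pd


def add_sharpe_feature(df, return_col, window=3):
    """Append a rolling Sharpe-like ratio per security (assumed definition)."""
    out = df.sort_values(["security", "date"]).copy()
    grouped = out.groupby("security")[return_col]
    rolling_mean = grouped.transform(lambda s: s.rolling(window).mean())
    rolling_std = grouped.transform(lambda s: s.rolling(window).std())
    # Zero volatility would divide by zero; treat it as missing instead
    out["sharpe"] = rolling_mean / rolling_std.replace(0.0, float("nan"))
    return out


def min_max_scale(df, columns):
    """Rescale each listed column to [0, 1]; constant columns become 0."""
    out = df.copy()
    for col in columns:
        lo, hi = out[col].min(), out[col].max()
        span = hi - lo
        out[col] = 0.0 if span == 0 else (out[col] - lo) / span
    return out


# Hypothetical demo data: two securities, three dates each
demo = pd.DataFrame({
    "date": pd.to_datetime(["2024-01-02", "2024-01-03", "2024-01-04"] * 2),
    "security": ["AAA"] * 3 + ["BBB"] * 3,
    "return1": [0.01, 0.03, 0.02, -0.01, 0.00, 0.01],
})
features = add_sharpe_feature(demo, return_col="return1", window=3)
scaled = min_max_scale(features, ["return1", "sharpe"])
```

The first `window - 1` rows of each security have no Sharpe value because the rolling window is not yet full. scikit-learn's `MinMaxScaler` performs the same rescaling and is the more likely choice in a full pipeline.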
- Python 3.7 or higher
git clone https://github.com/yourusername/stock-signal-prediction.git
cd stock-signal-prediction
It's recommended to use a virtual environment.
# Create a virtual environment
python -m venv venv
# Activate the virtual environment
# On Windows
venv\Scripts\activate
# On Unix or macOS
source venv/bin/activate
# Install required packages
pip install -r requirements.txt
Alternatively, install packages manually:
pip install numpy pandas plotly scikit-learn catboost xgboost
The script requires the following CSV data files placed in a data/ directory:
- Sector Data: data/data0.csv
  - Contains sector information for securities.
- Financial Data: data/data1.csv
  - Should include columns such as date, security, price, return30, and any other relevant features.
- Returns Data: data/returns.csv
  - Contains historical return data for securities with columns like date, security, and return1.
Ensure that dates are in a consistent format and that all required columns are present.
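A quick way to check these requirements before running the script is sketched below. It assumes pandas and only covers the two files whose columns are named above. The `REQUIRED_COLUMNS` mapping and the `load_checked` helper are hypothetical names, not part of the script, and the inline sample stands in for a real file such as data/returns.csv.

```python
import io

import pandas as pd

# Column names taken from the lists above; the dict itself is illustrative
REQUIRED_COLUMNS = {
    "financial": ["date", "security", "price", "return30"],
    "returns": ["date", "security", "return1"],
}


def load_checked(source, kind):
    """Read a CSV and fail early if a required column is missing."""
    df = pd.read_csv(source)
    missing = [c for c in REQUIRED_COLUMNS[kind] if c not in df.columns]
    if missing:
        raise ValueError(f"{kind} data is missing columns: {missing}")
    # to_datetime raises on unparseable values, surfacing bad dates early
    df["date"] = pd.to_datetime(df["date"])
    return df


# In-memory sample; pass a path like "data/returns.csv" in real use
sample = io.StringIO(
    "date,security,return1\n"
    "2024-01-02,AAPL,0.012\n"
    "2024-01-03,AAPL,-0.004\n"
)
returns = load_checked(sample, "returns")
```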
The script includes several configurable parameters:
- Training and Testing Periods
  - start_train: Start date for training data (e.g., 2017-01-01)
  - end_train: End date for training data (e.g., 2023-11-30)
  - start_test: Start date for testing data (e.g., 2024-01-01)
  - end_test: End date for testing data (e.g., 2024-06-30)
- Modeling Parameters
  - n_buys: Number of top buy signals to select (default: 10)
  - verbose: Toggle for detailed logging (default: True)
  - retrain_interval: Frequency of model retraining in steps (default: 300)
These parameters can be adjusted within the script to suit different analysis periods and modeling preferences.
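One way to keep these settings together and catch inconsistent values is sketched below. The `Config` dataclass and `should_retrain` helper are hypothetical, since the script may simply define plain variables. The date defaults are the example values given above, and the reading of retrain_interval as "refit at step 0 and every N steps after" is an assumption.

```python
from dataclasses import dataclass
from datetime import date


@dataclass(frozen=True)
class Config:
    # Example dates from this README, not confirmed script defaults
    start_train: date = date(2017, 1, 1)
    end_train: date = date(2023, 11, 30)
    start_test: date = date(2024, 1, 1)
    end_test: date = date(2024, 6, 30)
    n_buys: int = 10
    verbose: bool = True
    retrain_interval: int = 300

    def __post_init__(self):
        # The training window must end before the testing window begins
        if not self.start_train <= self.end_train < self.start_test <= self.end_test:
            raise ValueError("expected start_train <= end_train < start_test <= end_test")
        if self.n_buys < 1 or self.retrain_interval < 1:
            raise ValueError("n_buys and retrain_interval must be positive")


def should_retrain(step, cfg):
    """Assumed schedule: refit at step 0 and every retrain_interval steps after."""
    return step % cfg.retrain_interval == 0


cfg = Config()
```

Overriding a single value looks like `Config(n_buys=5)`. Overlapping training and testing windows raise a `ValueError` at construction time.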
- Prepare Data: Ensure all required CSV files are in the data/ directory.
- Configure Parameters: Modify the script parameters as needed for your analysis.
- Run the Script:
  python stock_signal_prediction.py
  Replace stock_signal_prediction.py with the actual script filename.
- View Outputs: Interactive plots will display automatically. Logs will be printed to the console.
The script generates the following outputs:
- Accuracy Metrics: Displays the accuracy of predictions for each test date.
- Buy Signals Plots:
  - Line plot for individual securities (e.g., AAPL).
  - Heatmap showing buy signals across securities and dates.
- Payoff Plot: Cumulative payoff over the testing period.
- Payoff Summary: Prints the total payoff percentage for the selected buy signals.
All plots are interactive and rendered using Plotly for enhanced visualization.
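The accuracy and payoff figures can be illustrated with a small sketch. The README does not define either metric, so this assumes that a buy signal counts as correct when the realized return is positive, that the payoff on a date is the equal-weighted mean return of the selected securities, and that the cumulative payoff is a running sum. The helper names and the returns shown are made up for illustration.

```python
from itertools import accumulate


def daily_payoff(selected_returns):
    """Equal-weighted mean return of the securities bought on one date."""
    return sum(selected_returns) / len(selected_returns)


def hit_rate(selected_returns):
    """Share of buy signals followed by a strictly positive return."""
    return sum(r > 0 for r in selected_returns) / len(selected_returns)


# Hypothetical realized returns of the selected buys, keyed by test date
realized = {
    "2024-01-02": [0.02, -0.01, 0.01],
    "2024-01-03": [0.00, 0.03, 0.03],
}

payoffs = [daily_payoff(r) for r in realized.values()]
cumulative = list(accumulate(payoffs))        # running sum over test dates
accuracy = {d: hit_rate(r) for d, r in realized.items()}
total_pct = 100 * cumulative[-1]              # total payoff in percent
```

A compounded payoff would multiply `(1 + payoff)` terms instead of summing them. Which convention the script uses is not stated here.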
The script employs the following machine learning classifiers:
- CatBoostClassifier
  - Handles categorical features effectively.
  - Configured with iterations=100 and random_seed=23.
- XGBClassifier (XGBoost, commented out)
  - Can be enabled by uncommenting the relevant lines.
  - Configured with n_estimators=100, learning_rate=0.1, random_state=42, and verbosity=0.
Each enabled model is evaluated using Stratified K-Fold cross-validation with n_splits=5, which keeps the class balance similar across folds.
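The cross-validation and signal-selection flow can be sketched with scikit-learn alone. To keep the example free of extra dependencies it swaps in scikit-learn's `GradientBoostingClassifier` as a stand-in. `CatBoostClassifier(iterations=100, random_seed=23)` follows the same fit/predict_proba interface and should be able to take its place. The synthetic data, the `shuffle=True` setting, and the variable names are assumptions for illustration only.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

rng = np.random.default_rng(23)
X = rng.normal(size=(200, 4))                                # synthetic features
y = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)   # synthetic buy labels

# Stand-in classifier, not the one the script uses
model = GradientBoostingClassifier(n_estimators=100, random_state=23)

# Five stratified folds, as described above; shuffling is an assumption
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=23)
scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy")

# Fit on all training rows, then rank candidates by buy probability
model.fit(X, y)
candidates = rng.normal(size=(30, 4))       # one row per security on a test date
buy_proba = model.predict_proba(candidates)[:, 1]
n_buys = 10
top_idx = np.argsort(buy_proba)[::-1][:n_buys]   # indices of the n_buys highest probabilities
```

`scores` holds one accuracy value per fold, and `top_idx` indexes the rows that would receive a buy signal for that date.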
The script includes a custom logging function to timestamp messages:
import datetime

def log(message):
    print(f'{datetime.datetime.now()} - {message}')
Key milestones and actions are logged to the console, including script start and end times, data setup, processing steps, and performance metrics.
This project is licensed under the MIT License.
- Libraries: NumPy, pandas, Plotly, scikit-learn, CatBoost, and XGBoost.
- Community: Thanks to the open-source community for providing powerful tools and resources that make projects like this possible.
For any questions or contributions, please open an issue or submit a pull request on the GitHub repository.