Skip to content

Commit

Permalink
Merge pull request #19 from gianlucamazza/develop
Browse files: browse the repository at this point in the history
Develop
  • Loading branch information
gianlucamazza authored Aug 10, 2024
2 parents 2f07d0b + 4d80810 commit 9505733
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 397 deletions.
110 changes: 60 additions & 50 deletions setup.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,72 +3,82 @@

with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
long_description = re.sub(r'!\[.*?\]\(.*?\)\n', '', long_description)

long_description = re.sub(r"!\[.*?\]\(.*?\)\n", "", long_description)

setup(
name='lstm_forecast',
version='0.1.2',
author='Gianluca Mazza',
author_email='[email protected]',
description='A package for LSTM-based financial time series forecasting',
name="lstm_forecast",
version="0.1.3",
author="Gianluca Mazza",
author_email="[email protected]",
description="A package for LSTM-based financial time series forecasting",
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/gianlucamazza/lstm_forecast',
long_description_content_type="text/markdown",
url="https://github.com/gianlucamazza/lstm_forecast",
project_urls={
'Bug Tracker': 'https://github.com/gianlucamazza/lstm_forecast/issues',
'Documentation': 'https://github.com/gianlucamazza/lstm_forecast#readme',
'Source Code': 'https://github.com/gianlucamazza/lstm_forecast',
"Bug Tracker": "https://github.com/gianlucamazza/lstm_forecast/issues",
"Documentation": "https://github.com/gianlucamazza/lstm_forecast#readme",
"Source Code": "https://github.com/gianlucamazza/lstm_forecast",
},
license='MIT',
license="MIT",
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development :: Libraries :: Python Modules',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Operating System :: OS Independent',
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Operating System :: OS Independent",
],
packages=find_packages(where='src'),
package_dir={'': 'src'},
python_requires='>=3.7',
packages=find_packages(where="src"),
package_dir={"": "src"},
python_requires=">=3.7",
install_requires=[
'pandas',
'ta',
'statsmodels',
'numpy',
'yfinance',
'matplotlib',
'torch',
'plotly',
'scikit-learn',
'xgboost',
'optuna',
'onnxruntime',
'onnx',
'flask',
"pandas>=2.2.2",
"ta>=0.11.0",
"statsmodels>=0.14.2",
"numpy>=1.26.4",
"yfinance>=0.2.41",
"matplotlib>=3.9.1",
"torch>=2.5.0",
"plotly>=5.3.1",
"scikit-learn>=1.5.1",
"xgboost>=2.1.0",
"optuna>=3.6.1",
"onnxruntime>=1.18.1",
"onnx>=1.16.2",
"Flask>=3.0.3",
],
extras_require={
"dev": [
"pytest>=8.3.2",
"sphinx>=7.4.7",
"twine>=5.1.1",
"black>=24.8.0",
"flake8>=7.1.1",
"pre-commit>=3.7.1",
],
},
entry_points={
'console_scripts': [
'lstm_forecast=lstm_forecast.cli:main',
"console_scripts": [
"lstm_forecast=lstm_forecast.cli:main",
],
},
include_package_data=True,
package_data={
'': ['*.json', '*.html', '*.png'],
"": ["*.json", "*.html", "*.png"],
},
extras_require={
'dev': [
'pytest',
'sphinx',
'twine',
"dev": [
"pytest",
"sphinx",
"twine",
],
},
keywords='lstm forecasting finance time series deep learning',
keywords="lstm forecasting finance time series deep learning",
)
2 changes: 1 addition & 1 deletion src/lstm_forecast/cli.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from lstm_forecast.train import main as train_main
from lstm_forecast.predict import main as predict_main
from lstm_forecast.api.app import create_app
from lstm_forecast.api import create_app
from lstm_forecast.config import Config


Expand Down
86 changes: 50 additions & 36 deletions src/lstm_forecast/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -85,19 +85,15 @@ def save_training_state(model, optimizer, epoch, best_val_loss, config):
def train_model(
config,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
train_loader: torch.utils.data.DataLoader,
val_loader: torch.utils.data.DataLoader,
num_epochs: int,
learning_rate: float,
model_dir: str,
weight_decay: float,
_device: torch.device,
fold_idx: int = None,
) -> None:
"""Train the model with early stopping."""
optimizer = torch.optim.Adam(
model.parameters(), lr=learning_rate, weight_decay=weight_decay
)
loss_fn = torch.nn.MSELoss()

early_stopping = EarlyStopping(
Expand Down Expand Up @@ -238,44 +234,55 @@ def main(config: Config):
weight_decay=config.model_settings.get("weight_decay", 0.0),
)

global_early_stopping = EarlyStopping(
patience=10,
verbose=True,
path=f"{config.training_settings['model_dir']}/{config.data_settings['symbol']}_best_model.pth",
)

all_train_losses = []
all_val_losses = []

for fold_idx, (train_loader, val_loader) in enumerate(
train_val_loaders, 1
):
logger.info(f"Training fold {fold_idx}")
# Use the same model instance for all folds, just move it to the correct device
model.to(device)

# Train the model
for epoch in range(config.training_settings["epochs"]):
train_model(
config,
model,
train_loader,
val_loader,
num_epochs=config.training_settings["epochs"],
learning_rate=config.model_settings.get(
"learning_rate", 0.001
),
model_dir=config.training_settings["model_dir"],
weight_decay=config.model_settings.get(
"weight_decay", 0.0
),
_device=device,
fold_idx=fold_idx,
)
train_losses, val_losses = train_model(
config,
model,
optimizer,
train_loader,
val_loader,
num_epochs=config.training_settings["epochs"],
model_dir=config.training_settings["model_dir"],
_device=device,
fold_idx=fold_idx,
)

# Evaluate the model on the validation set
val_loss = evaluate_model(
model, val_loader, torch.nn.MSELoss(), device
)
all_train_losses.extend(train_losses)
all_val_losses.extend(val_losses)

if val_loss < best_val_loss:
best_val_loss = val_loss
best_model = model.state_dict()
# Evaluate the model on the validation set
val_loss = evaluate_model(
model, val_loader, torch.nn.MSELoss(), device
)

save_training_state(
model, optimizer, epoch, best_val_loss, config
global_early_stopping(val_loss, model)
if global_early_stopping.early_stop:
logger.info(
"Global early stopping triggered. Stopping training."
)
break

if val_loss < best_val_loss:
best_val_loss = val_loss
best_model = model.state_dict().copy()

save_training_state(
model, optimizer, fold_idx, best_val_loss, config
)

if best_model is not None:
# Save and export the best model
Expand All @@ -290,14 +297,21 @@ def main(config: Config):

# Save the training state including the optimizer state
save_training_state(
final_model, optimizer, epoch, best_val_loss, config
final_model, optimizer, fold_idx, best_val_loss, config
)
else:
logger.error("No best model found to export.")

plot_training_history(all_train_losses, all_val_losses, config)

except ValueError as ve:
logger.error(f"Value error occurred: {str(ve)}")
except torch.cuda.CudaError as ce:
logger.error(f"CUDA error occurred: {str(ce)}")
except Exception as e:
logger.error(f"An error occurred during training: {str(e)}")
raise
logger.error(f"An unexpected error occurred: {str(e)}")
finally:
logger.info("Training process completed.")


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 9505733

Please sign in to comment.