Skip to content

Commit

Permalink
Implement GA for hyperparameter optimisation
Browse files Browse the repository at this point in the history
  • Loading branch information
noahbclarkson committed Nov 29, 2024
1 parent 93d103a commit 567fac6
Show file tree
Hide file tree
Showing 8 changed files with 467 additions and 44 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ linfa = "0.7"
linfa-pls = "0.7"
ndarray = "0.15"
derive_builder = "0.20"
genevo = "0.7"

[dev-dependencies]
tempfile = "3.14"
9 changes: 5 additions & 4 deletions src/algorithm/algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl Algorithm {
let total_size = ds.len()?;
let test_data_size = total_size / count;

let test_results: Vec<TestData> = (1..count)
let test_results: Vec<TestData> = (0..count)
.map(|i| -> Result<TestData, KryptoError> {
let start = i * test_data_size;
let end = match i == count - 1 {
Expand All @@ -84,10 +84,11 @@ impl Algorithm {
};
let features = ds.get_features();
let candles = ds.get_candles();
let labels = ds.get_labels();
let test_features = &features[start..end];
let test_candles = &candles[start..end];
let train_features = features[..start].to_vec();
let train_labels = ds.get_labels()[..start].to_vec();
let train_features = [&features[..start], &features[end..]].concat();
let train_labels = [&labels[..start], &labels[end..]].concat();

let pls = get_pls(train_features, train_labels, settings.n)?;
let predictions = predict(&pls, test_features)?;
Expand All @@ -107,7 +108,7 @@ impl Algorithm {
let median_return = median(&TestData::get_monthly_returns(&test_results));
let median_accuracy = median(&TestData::get_accuracies(&test_results));
let result = AlgorithmResult::new(median_return, median_accuracy);
info!("Backtest result: {}", result);
debug!("Backtest result: {}", result);
Ok(result)
}

Expand Down
21 changes: 21 additions & 0 deletions src/data/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,26 @@ impl IntervalData {

SymbolDataset::new(features, labels, candles)
}

pub fn get_specific_tickers(&self, tickers: &Vec<String>) -> Self {
// Create a new symbol_data_map with only the specified tickers
let mut new_symbol_data_map: HashMap<String, RawSymbolData> = HashMap::new();

for ticker in tickers {
if let Some(symbol_data) = self.symbol_data_map.get(ticker) {
new_symbol_data_map.insert(ticker.clone(), symbol_data.clone());
}
}

// Recompute the normalized predictors with the new symbol data
let records = get_records(&new_symbol_data_map);
let normalized_predictors = get_normalized_predictors(records);

Self {
symbol_data_map: new_symbol_data_map,
normalized_predictors,
}
}
}

fn get_normalized_predictors(records: Vec<Vec<f64>>) -> Vec<Vec<f64>> {
Expand Down Expand Up @@ -253,6 +273,7 @@ impl SymbolDataset {
}
}

#[derive(Debug, Clone)]
struct RawSymbolData {
candles: Vec<Candlestick>,
technicals: Vec<Technicals>,
Expand Down
2 changes: 1 addition & 1 deletion src/data/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::error::ParseIntervalError;

/// Represents various time intervals.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, PartialOrd)]
pub enum Interval {
OneMinute,
ThreeMinutes,
Expand Down
34 changes: 33 additions & 1 deletion src/data/technicals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::candlestick::Candlestick;

use ta::{indicators::*, Next};

pub const TECHNICAL_COUNT: usize = 8;
pub const TECHNICAL_COUNT: usize = 10;

#[derive(Debug, Clone)]
pub struct Technicals {
Expand All @@ -14,6 +14,8 @@ pub struct Technicals {
efficiency_ratio: f64,
percentage_change_ema: f64,
volume_percentage_change_ema: f64,
bb_pct: f64,
candlestick_ratio: f64,
}

impl Technicals {
Expand All @@ -26,10 +28,12 @@ impl Technicals {
let mut efficiency_ratio = EfficiencyRatio::default();
let mut pc_ema = PercentageChangeEMA::default();
let mut volume_pc_ema = PercentageChangeEMA::default();
let mut bollinger_bands = BollingerBands::default();

let mut result = Vec::new();

for candle in data {
let bb = bollinger_bands.next(candle.close);
let technicals = Self {
rsi: rsi.next(candle),
fast_stochastic: fast_stochastic.next(candle),
Expand All @@ -39,6 +43,8 @@ impl Technicals {
efficiency_ratio: efficiency_ratio.next(candle),
percentage_change_ema: pc_ema.next(candle.close),
volume_percentage_change_ema: volume_pc_ema.next(candle.volume),
bb_pct: (candle.close - bb.lower) / (bb.upper - bb.lower),
candlestick_ratio: candlestick_ratio(candle),
};
result.push(technicals);
}
Expand All @@ -55,6 +61,8 @@ impl Technicals {
self.efficiency_ratio,
self.percentage_change_ema,
self.volume_percentage_change_ema,
self.bb_pct,
self.candlestick_ratio,
]
}
}
Expand Down Expand Up @@ -100,3 +108,27 @@ impl PercentageChangeEMA {
}
}
}

/**
Calculates the candlestick ratio for a given candlestick.
The formula is: tanh((upper_wick / body) - (lower_wick / body))
## Arguments
- `candle`: A Candlestick struct.
## Returns
The candlestick ratio.
*/
fn candlestick_ratio(candle: &Candlestick) -> f64 {
let top = candle.close.max(candle.open);
let bottom = candle.close.min(candle.open);
let upper_wick = candle.high - top;
let lower_wick = bottom - candle.low;
let body = top - bottom;
if body == 0.0 {
return 0.0;
}
let ratio = (upper_wick / body) - (lower_wick / body);
ratio.tanh()
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ pub mod config;
pub mod data;
pub mod error;
pub mod logging;
pub mod optimisation;
pub mod util;
138 changes: 100 additions & 38 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
use core::f64;
use std::sync::Arc;

use genevo::{
ga::genetic_algorithm,
operator::prelude::{ElitistReinserter, MaximizeSelector},
prelude::{build_population, simulate, GenerationLimit, Population},
simulation::*,
};
use krypto::{
algorithm::algo::{Algorithm, AlgorithmSettings},
config::KryptoConfig,
data::dataset::Dataset,
error::KryptoError,
logging::setup_tracing,
optimisation::{
TradingStrategy, TradingStrategyCrossover, TradingStrategyFitnessFunction,
TradingStrategyGenomeBuilder, TradingStrategyMutation,
},
};
use tracing::{error, info};

const MAX_N: usize = 50;
const MAX_DEPTH: usize = 50;
const MAX_DEPTH: usize = 40;

pub fn main() {
let (_, file_guard) = setup_tracing(Some("logs")).expect("Failed to set up tracing");
Expand All @@ -25,48 +35,100 @@ fn run() -> Result<(), KryptoError> {
let config = KryptoConfig::read_config::<&str>(None)?;
let dataset = Dataset::load(&config)?;

let mut best_return = f64::NEG_INFINITY;
let mut best_algorithm: Option<Algorithm> = None;
let population_size = 250;
let selection_ratio = 0.7;
let num_individuals_per_parents = 2;
let mutation_rate = 0.015;
let reinsertion_ratio = 0.7;
let generation_limit = 100; // Adjust as needed

let mut csv = csv::Writer::from_path("results.csv")?;
let available_tickers = config.symbols.clone();
let available_intervals = config.intervals.clone();

csv.write_record([
"n",
"depth",
"ticker",
"monthly_return",
"accuracy",
"interval",
])?;
let config = Arc::new(config);
let dataset = Arc::new(dataset);

for (interval, interval_data) in dataset.get_map() {
info!("Interval: {}", interval);
let all_settings = AlgorithmSettings::all(config.symbols.clone(), MAX_N, MAX_DEPTH);
for settings in all_settings {
let algorithm = Algorithm::load(interval_data, settings.clone(), &config)?;
let monthly_return = algorithm.get_monthly_return();
csv.write_record(&[
settings.n.to_string(),
settings.depth.to_string(),
settings.symbol.to_string(),
monthly_return.to_string(),
algorithm.get_accuracy().to_string(),
interval.to_string(),
])?;
let initial_population: Population<TradingStrategy> = build_population()
.with_genome_builder(TradingStrategyGenomeBuilder::new(
available_tickers.clone(),
available_intervals.clone(),
MAX_N,
MAX_DEPTH,
))
.of_size(population_size)
.uniform_at_random();

if monthly_return > best_return {
best_return = monthly_return;
info!("New best algorithm: {}", &algorithm);
best_algorithm = Some(algorithm);
}
let ga = genetic_algorithm()
.with_evaluation(TradingStrategyFitnessFunction::new(
config.clone(),
dataset.clone(),
))
.with_selection(MaximizeSelector::new(
selection_ratio,
num_individuals_per_parents,
))
.with_crossover(TradingStrategyCrossover)
.with_mutation(TradingStrategyMutation::new(
mutation_rate,
available_tickers.clone(),
available_intervals.clone(),
MAX_N,
MAX_DEPTH,
))
.with_reinsertion(ElitistReinserter::new(
TradingStrategyFitnessFunction::new(config, dataset),
true,
reinsertion_ratio,
))
.with_initial_population(initial_population)
.build();

csv.flush()?;
}
}
let mut sim = simulate(ga)
.until(GenerationLimit::new(generation_limit))
.build();

info!("Starting Genetic Algorithm");

match best_algorithm {
Some(algorithm) => info!("Best Algorithm: {}", &algorithm),
None => info!("No algorithm found."),
let mut csv = csv::Writer::from_path("ga-results.csv")?;
csv.write_record(["Generation", "Fitness", "Strategy"])?;
csv.flush()?;
// Run the simulation loop
loop {
match sim.step() {
Ok(SimResult::Intermediate(step)) => {
let best_solution = &step.result.best_solution;
info!(
"Generation {}: Best fitness: {:.2}%, Strategy: {:?}",
step.iteration, best_solution.solution.fitness as f64 / 100.0, best_solution.solution.genome
);
csv.write_record([
step.iteration.to_string(),
(best_solution.solution.fitness as f64 / 100.0).to_string(),
best_solution.solution.genome.to_string(),
])?;
csv.flush()?;
}
Ok(SimResult::Final(step, processing_time, _, stop_reason)) => {
let best_solution = &step.result.best_solution;
info!(
"Simulation ended: {} in {}s",
stop_reason,
processing_time.duration().num_seconds()
);
info!(
"Best strategy found in generation {}: Fitness: {:.2}%",
best_solution.generation,
best_solution.solution.fitness as f64 / 100.0
);
// Display the best trading strategy
info!("Best trading strategy: {}", best_solution.solution.genome);
break;
}
Err(error) => {
error!("Error: {}", error);
break;
}
}
}

Ok(())
Expand Down
Loading

0 comments on commit 567fac6

Please sign in to comment.