Skip to content

Commit

Permalink
refactor: Improve optimisation techniques
Browse files Browse the repository at this point in the history
  • Loading branch information
noahbclarkson committed Dec 2, 2024
1 parent b6822a7 commit 92bbf9f
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 86 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ edition = "2021"

[dependencies]
binance-rs-async = "1.3.3"
serde = { version = "1.0", features = ["derive"] }
serde = { version = "1.0.215", features = ["derive"] }
serde_yaml = "0.9"
chrono = "0.4"
ta = "0.5"
csv = "1.3"
csv = "1.3.1"
thiserror = "2.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = [
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = [
"fmt",
"env-filter",
"chrono",
Expand Down
12 changes: 11 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
fmt,
fs::File,
io::{BufReader, Write as _},
path::Path,
Expand Down Expand Up @@ -153,7 +154,7 @@ impl KryptoConfig {
}
}
}
info!("Configuration loaded successfully");
info!("Configuration loaded successfully: {}", config);
Ok(config)
}

Expand Down Expand Up @@ -192,3 +193,12 @@ impl KryptoConfig {
T::new(self.api_key.clone(), self.api_secret.clone())
}
}

impl fmt::Display for KryptoConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Start Date: {}\nSymbols: {:?}\nIntervals: {:?}\nCross Validations: {}\nFee: {:?}\nMax N: {}\nMax Depth: {}\nGeneration Limit: {}\nPopulation Size: {}\nMutation Rate: {}\nTechnicals: {:?}\nMargin: {}", self.start_date, self.symbols, self.intervals, self.cross_validations, self.fee, self.max_n, self.max_depth, self.generation_limit, self.population_size, self.mutation_rate, self.technicals, self.margin
)
}
}
Empty file removed src/krypto_account.rs
Empty file.
16 changes: 8 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async fn run() -> Result<(), KryptoError> {

let available_tickers = config.symbols.clone();
let available_intervals = config.intervals.clone();
let available_tecnicals = config.technicals.clone();
let available_technicals = config.technicals.clone();

let config = Arc::new(config);
let dataset = Arc::new(dataset);
Expand All @@ -49,9 +49,9 @@ async fn run() -> Result<(), KryptoError> {
.with_genome_builder(TradingStrategyGenomeBuilder::new(
available_tickers.clone(),
available_intervals.clone(),
available_tecnicals.clone(),
config.max_n,
available_technicals.clone(),
config.max_depth,
config.max_n,
))
.of_size(config.population_size)
.uniform_at_random();
Expand All @@ -61,7 +61,7 @@ async fn run() -> Result<(), KryptoError> {
config.clone(),
dataset.clone(),
available_tickers.clone(),
available_tecnicals.clone(),
available_technicals.clone(),
))
.with_selection(MaximizeSelector::new(
selection_ratio,
Expand All @@ -74,15 +74,15 @@ async fn run() -> Result<(), KryptoError> {
config.mutation_rate,
available_tickers.clone(),
available_intervals.clone(),
config.max_n,
config.max_depth,
config.max_n,
))
.with_reinsertion(ElitistReinserter::new(
TradingStrategyFitnessFunction::new(
config.clone(),
dataset,
available_tickers.clone(),
available_tecnicals.clone(),
available_technicals.clone(),
),
true,
reinsertion_ratio,
Expand All @@ -107,7 +107,7 @@ async fn run() -> Result<(), KryptoError> {
let phenotype = best_solution
.solution
.genome
.to_phenotype(&available_tickers, &available_tecnicals);
.to_phenotype(&available_tickers, &available_technicals);
info!(
"Generation {}: Best fitness: {:.2}%, Strategy: {:?}",
step.iteration,
Expand Down Expand Up @@ -136,7 +136,7 @@ async fn run() -> Result<(), KryptoError> {
let phenotype = best_solution
.solution
.genome
.to_phenotype(&available_tickers, &available_tecnicals);
.to_phenotype(&available_tickers, &available_technicals);
// Display the best trading strategy
info!("Best trading strategy: {}", phenotype);
break;
Expand Down
158 changes: 85 additions & 73 deletions src/optimisation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::{fmt, panic, sync::Arc};
use std::{
collections::HashMap,
fmt, panic,
sync::{Arc, Mutex},
};

use genevo::{
genetic::{Children, Parents},
Expand All @@ -14,7 +18,7 @@ use crate::{
data::{dataset::Dataset, interval::Interval},
};

#[derive(Clone, Debug, PartialEq, PartialOrd)]
#[derive(Clone, Debug, PartialEq, PartialOrd, Hash, Eq)]
pub struct TradingStrategyGenome {
n: usize,
d: usize,
Expand Down Expand Up @@ -137,23 +141,32 @@ impl TradingStrategyGenomeBuilder {
where
R: Rng + Sized,
{
let n = self.available_tickers.len();
let mut tickers = Vec::new();
for _ in 0..n {
tickers.push(r.gen_bool(0.25));
let num_selected = r.gen_range(1..=self.available_tickers.len());
let mut tickers = vec![false; self.available_tickers.len()];
for _ in 0..num_selected {
let pos = r.gen_range(0..self.available_tickers.len());
tickers[pos] = true;
}
// Ensure the symbol is included
let symbol = self.available_tickers.choose(r).unwrap().clone();
let pos_of_symbol = self
.available_tickers
.iter()
.position(|s| s == &symbol)
.unwrap();
tickers[pos_of_symbol] = true;
(tickers, symbol)
}

fn technicals<R>(&self, r: &mut R) -> Vec<bool>
where
R: Rng + Sized,
{
let n = self.available_technicals.len();
let mut technicals = Vec::new();
for _ in 0..n {
technicals.push(r.gen_bool(0.5));
let num_selected = r.gen_range(1..=self.available_technicals.len());
let mut technicals = vec![false; self.available_technicals.len()];
for _ in 0..num_selected {
let pos = r.gen_range(0..self.available_technicals.len());
technicals[pos] = true;
}
technicals
}
Expand All @@ -164,21 +177,14 @@ impl GenomeBuilder<TradingStrategyGenome> for TradingStrategyGenomeBuilder {
where
R: Rng + Sized,
{
let (mut tickers, symbol) = self.tickers(rng);
let pos_of_symbol = self
.available_tickers
.iter()
.position(|s| s == &symbol)
.unwrap();
tickers[pos_of_symbol] = true;
let (tickers, symbol) = self.tickers(rng);
let depth = rng.gen_range(1..=self.max_depth);
let technicals = self.technicals(rng);
let technical_count = technicals.iter().filter(|b| **b).count();
let tickers_count = tickers.iter().filter(|b| **b).count();
let max_n = depth * tickers_count * technical_count;
let n = rng.gen_range(1..=max_n.min(self.max_n));
let interval = *self.available_intervals.choose(rng).unwrap();

TradingStrategyGenome {
n,
d: depth,
Expand All @@ -196,6 +202,7 @@ pub struct TradingStrategyFitnessFunction {
dataset: Arc<Dataset>,
available_tickers: Vec<String>,
available_technicals: Vec<String>,
fitness_cache: Arc<Mutex<HashMap<TradingStrategyGenome, i64>>>,
}

impl fmt::Debug for TradingStrategyFitnessFunction {
Expand All @@ -211,45 +218,49 @@ impl TradingStrategyFitnessFunction {
available_tickers: Vec<String>,
available_technicals: Vec<String>,
) -> Self {
let fitness_cache = Arc::new(Mutex::new(HashMap::new()));
Self {
config,
dataset,
available_tickers,
available_technicals,
fitness_cache,
}
}

pub fn to_phenotype(&self, genome: &TradingStrategyGenome) -> TradingStrategy {
let tickers = genome
.tickers
.iter()
.zip(self.available_tickers.iter())
.filter_map(|(b, s)| if *b { Some(s.clone()) } else { None })
.collect();
let technicals = genome
.technicals
.iter()
.zip(self.available_technicals.iter())
.filter_map(|(b, s)| if *b { Some(s.clone()) } else { None })
.collect();
TradingStrategy::new(
genome.n,
genome.d,
genome.interval,
tickers,
genome.symbol.clone(),
technicals,
)
pub fn clear_cache(&self) {
let mut cache = self.fitness_cache.lock().unwrap();
cache.clear();
}

pub fn get_cache(&self) -> std::sync::MutexGuard<HashMap<TradingStrategyGenome, i64>> {
self.fitness_cache.lock().unwrap()
}

pub fn cache(&self, genome: TradingStrategyGenome, fitness: i64) {
let mut cache = self.fitness_cache.lock().unwrap();
cache.insert(genome, fitness);
}

pub fn cache_contains(&self, genome: &TradingStrategyGenome) -> bool {
let cache = self.fitness_cache.lock().unwrap();
cache.contains_key(genome)
}
}

impl FitnessFunction<TradingStrategyGenome, i64> for TradingStrategyFitnessFunction {
#[tracing::instrument(skip(self))]
#[tracing::instrument(skip(self, a))]
fn fitness_of(&self, a: &TradingStrategyGenome) -> i64 {
let strategy = self.to_phenotype(a);
if self.cache_contains(a) {
return *self.get_cache().get(a).unwrap();
}

let strategy = a.to_phenotype(&self.available_tickers, &self.available_technicals);
debug!("Evaluating fitness of strategy: {}", strategy);
let data = self.dataset.get(&a.interval).unwrap();
let data = panic::catch_unwind(|| {data.get_specific_tickers_and_technicals(&strategy.tickers, &strategy.technicals)});
let data = panic::catch_unwind(|| {
data.get_specific_tickers_and_technicals(&strategy.tickers, &strategy.technicals)
});
let data = match data {
Ok(data) => data,
Err(e) => {
Expand All @@ -271,8 +282,14 @@ impl FitnessFunction<TradingStrategyGenome, i64> for TradingStrategyFitnessFunct
if monthly_return.is_nan() || monthly_return.is_infinite() {
return i64::MIN;
}
debug!("Evaluated fitness: {:.2}%", monthly_return * 100.0);
(algorithm.get_monthly_return() * 10_000.0) as i64
debug!(
"Evaluated fitness of strategy {}: {:.2}%",
strategy,
monthly_return * 100.0
);
let fitness = (monthly_return * 10_000.0) as i64;
self.cache(a.clone(), fitness);
fitness
}

fn average(&self, a: &[i64]) -> i64 {
Expand Down Expand Up @@ -335,11 +352,15 @@ impl CrossoverOp<TradingStrategyGenome> for TradingStrategyCrossover {
child_technicals.push(*t2);
}
}
if child_technicals.iter().all(|&b| !b) {
let index = rng.gen_range(0..child_technicals.len());
child_technicals[index] = true;
}

let tech_count = child_technicals.iter().filter(|b| **b).count();
let tickers_count = child_tickers.iter().filter(|b| **b).count();
let child_d = (parent1.d as f64 * 0.5 + parent2.d as f64 * 0.5) as usize;
let mut child_n = (parent1.n as f64 * 0.5 + parent2.n as f64 * 0.5) as usize;
let child_d = (parent1.d + parent2.d) / 2;
let mut child_n = (parent1.n + parent2.n) / 2;
if child_n > child_d * tickers_count * tech_count {
child_n = child_d * tickers_count * tech_count;
}
Expand Down Expand Up @@ -414,35 +435,26 @@ impl MutationOp<TradingStrategyGenome> for TradingStrategyMutation {
R: Rng + Sized,
{
let mut new_genome = genome.clone();
if rng.gen_bool(self.mutation_rate) {
let (tickers, symbol) = TradingStrategyGenomeBuilder::new(
self.available_tickers.clone(),
self.available_intervals.clone(),
self.available_tickers.clone(),
self.max_depth,
self.max_n,
)
.tickers(rng);
new_genome.tickers = tickers;
new_genome.symbol = symbol;
let pos_of_symbol = self
.available_tickers
.iter()
.position(|s| s == &new_genome.symbol)
.unwrap();
new_genome.tickers[pos_of_symbol] = true;

for ticker in new_genome.tickers.iter_mut() {
if rng.gen_bool(self.mutation_rate) {
*ticker = !*ticker;
}
}

if rng.gen_bool(self.mutation_rate) {
let technicals = TradingStrategyGenomeBuilder::new(
self.available_tickers.clone(),
self.available_intervals.clone(),
self.available_tickers.clone(),
self.max_depth,
self.max_n,
)
.technicals(rng);
new_genome.technicals = technicals;
// Ensure symbol is still included
let pos_of_symbol = self
.available_tickers
.iter()
.position(|s| s == &new_genome.symbol)
.unwrap();
new_genome.tickers[pos_of_symbol] = true;

// Mutate individual bits in technicals
for technical in new_genome.technicals.iter_mut() {
if rng.gen_bool(self.mutation_rate) {
*technical = !*technical;
}
}

if rng.gen_bool(self.mutation_rate) {
Expand Down

0 comments on commit 92bbf9f

Please sign in to comment.