From b8480e76f8f64848ad07b656397b0e8f06b62ad0 Mon Sep 17 00:00:00 2001 From: Noah Clarkson Date: Tue, 19 Nov 2024 12:20:00 +1300 Subject: [PATCH] Refactor Algorithm and TestData for better structure and error handling - Introduce `AlgorithmResult` struct to encapsulate monthly return and accuracy. - Move `backtest` function into the `Algorithm` impl block and refactor its logic. - Update `Algorithm` struct to use `result: AlgorithmResult` instead of separate fields. - Add error handling in `TestData::new`: it now returns a `Result` and rejects empty inputs as well as candles and predictions of unequal length. - Refactor position handling logic in `TestData` for clarity and correctness. - Add new error variants `EmptyCandlesAndPredictions` and `UnequalCandlesAndPredictions` in `KryptoError`. - Add `predict` function in `src/algorithm/pls.rs` for PLS predictions. - Add `median` helper in the new `src/util/math_utils.rs` and use it to aggregate cross-validation results. - Rename `days_between_datetime` in `date_utils` to `days_between` (taking `DateTime` parameters) and remove the old `NaiveDate`-based `days_between`. - Update `main.rs` for `get_monthly_return` now returning `f64` by value, and widen the `n` and `depth` search ranges from `1..25` to `1..50`. - Upgrade GitHub Actions `upload-artifact` from `v2` to `v4` in `rust.yml`. This refactoring improves code organization, error handling, and overall code readability. 
--- .github/workflows/rust.yml | 2 +- src/algorithm/algo.rs | 268 +++++++++++++++++++------------------ src/algorithm/pls.rs | 9 +- src/algorithm/test_data.rs | 180 ++++++++++++++----------- src/data/dataset.rs | 2 +- src/error.rs | 4 + src/main.rs | 8 +- src/util/date_utils.rs | 6 +- src/util/math_utils.rs | 10 ++ src/util/mod.rs | 1 + 10 files changed, 271 insertions(+), 219 deletions(-) create mode 100644 src/util/math_utils.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c8416a2..de9a533 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -70,7 +70,7 @@ jobs: - name: Upload Artifact (Windows) if: runner.os == 'Windows' - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: krypto-${{ runner.os }} path: artifacts/krypto.exe diff --git a/src/algorithm/algo.rs b/src/algorithm/algo.rs index a397e6b..b32cc4d 100644 --- a/src/algorithm/algo.rs +++ b/src/algorithm/algo.rs @@ -1,23 +1,23 @@ use std::fmt; -use linfa::traits::Predict as _; use linfa_pls::PlsRegression; -use ndarray::Array2; use tracing::{debug, info, instrument}; use crate::{ - algorithm::{pls::get_pls, test_data::TestData}, + algorithm::{ + pls::{get_pls, predict}, + test_data::TestData, + }, config::KryptoConfig, data::{candlestick::Candlestick, dataset::IntervalData}, error::KryptoError, - util::matrix_utils::normalize_by_columns, + util::{math_utils::median, matrix_utils::normalize_by_columns}, }; pub struct Algorithm { pub pls: PlsRegression, settings: AlgorithmSettings, - monthly_return: f64, - accuracy: f64, + result: AlgorithmResult, } impl Algorithm { @@ -27,27 +27,124 @@ impl Algorithm { settings: AlgorithmSettings, config: &KryptoConfig, ) -> Result { - let (monthly_return, accuracy) = backtest(dataset, settings.clone(), config)?; - let (features, labels, _) = get_overall_dataset(dataset, settings.clone()); + let result = Self::backtest(dataset, &settings, config)?; + let (features, labels, _) = 
Self::prepare_dataset(dataset, &settings); let pls = get_pls(features, labels, settings.n)?; Ok(Self { pls, settings, - monthly_return, - accuracy, + result, }) } + fn backtest( + dataset: &IntervalData, + settings: &AlgorithmSettings, + config: &KryptoConfig, + ) -> Result { + debug!("Running backtest"); + + let (features, labels, candles) = Self::prepare_dataset(dataset, settings); + let count = config.cross_validations; + let total_size = candles.len(); + let test_data_size = total_size / count; + + let mut test_results = Vec::with_capacity(count); + + for i in 0..count { + let start = i * test_data_size; + let end = match i == count - 1 { + true => total_size, + false => (i + 1) * test_data_size, + }; + + let test_features = &features[start..end]; + let test_candles = &candles[start..end]; + let train_features = [&features[..start], &features[end..]].concat(); + let train_labels = [&labels[..start], &labels[end..]].concat(); + + let pls = get_pls(train_features, train_labels, settings.n)?; + let predictions = predict(&pls, test_features); + + let test_data = TestData::new(predictions, test_candles.to_vec(), config)?; + debug!( + "Cross-validation {} ({}-{}): {}", + i + 1, + start, + end, + test_data + ); + test_results.push(test_data); + } + + let median_return = median( + &test_results + .iter() + .map(|d| d.monthly_return) + .collect::>(), + ); + let median_accuracy = median(&test_results.iter().map(|d| d.accuracy).collect::>()); + let result = AlgorithmResult::new(median_return, median_accuracy); + info!("Backtest result: {}", result); + Ok(result) + } + + #[instrument(skip(dataset))] + fn prepare_dataset( + dataset: &IntervalData, + settings: &AlgorithmSettings, + ) -> (Vec>, Vec, Vec) { + let records = dataset.get_records(); + let normalized_predictors = normalize_by_columns(records) + .into_iter() + .map(|row| { + row.into_iter() + .map(|v| if v.is_nan() { 0.0 } else { v }) + .collect() + }) + .collect::>>(); + + let features = normalized_predictors + 
.windows(settings.depth) + .map(|window| window.iter().flatten().cloned().collect()) + .collect::>>(); + let features = features[..features.len() - 1].to_vec(); + + let symbol_data = dataset + .get(&settings.symbol) + .expect("Symbol not found in dataset"); + + let labels: Vec = symbol_data + .get_labels() + .iter() + .skip(settings.depth) + .map(|&v| if v.is_nan() { 1.0 } else { v }) + .collect(); + + let candles: Vec = symbol_data + .get_candles() + .iter() + .skip(settings.depth) + .cloned() + .collect(); + + debug!("Features shape: {}x{}", features.len(), features[0].len()); + debug!("Labels count: {}", labels.len()); + debug!("Candles count: {}", candles.len()); + + (features, labels, candles) + } + pub fn get_symbol(&self) -> &str { &self.settings.symbol } - pub fn get_monthly_return(&self) -> &f64 { - &self.monthly_return + pub fn get_monthly_return(&self) -> f64 { + self.result.monthly_return } pub fn get_accuracy(&self) -> f64 { - self.accuracy + self.result.accuracy } pub fn get_n_components(&self) -> usize { @@ -63,126 +160,12 @@ impl fmt::Display for Algorithm { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "Algorithm: ({}) | Monthly Return: {:.2} | Accuracy: {:.2}%", - self.settings, - self.monthly_return * 100.0, - self.accuracy * 100.0 + "Algorithm: ({}) | Result: ({})", + self.settings, self.result ) } } -#[instrument(skip(dataset, config, settings))] -fn backtest( - dataset: &IntervalData, - settings: AlgorithmSettings, - config: &KryptoConfig, -) -> Result<(f64, f64), KryptoError> { - info!("Running backtest"); - let (features, labels, candles) = get_overall_dataset(dataset, settings.clone()); - let count = config.cross_validations; - let total_size = candles.len(); - let test_data_size = (total_size as f64 / count as f64).floor() as usize - 1; - let mut test_datas = Vec::new(); - for i in 0..count { - let start = i * test_data_size; - let end = if i == count - 1 { - total_size - } else { - (i + 1) * test_data_size - }; - 
debug!("Start: {} | End: {}", start, end); - let mut train_features = features.clone(); - let test_features: Vec> = train_features.drain(start..end).collect(); - let mut train_labels = labels.clone(); - train_labels.drain(start..end); - let test_candles = candles.clone().drain(start..end).collect(); - let pls = get_pls(train_features, train_labels, settings.n)?; - let predictions = get_predictions(pls, test_features); - debug!("Running cross validation: {}/{}", i + 1, count); - let test_data = TestData::new(predictions, test_candles, config); - debug!("Cross Validation {}: {}", i + 1, test_data); - test_datas.push(test_data); - } - let returns = test_datas - .iter() - .map(|d| d.monthly_return) - .collect::>(); - let accuracies = test_datas.iter().map(|d| d.accuracy).collect::>(); - let median_return = returns[returns.len() / 2]; - let median_accuracy = accuracies[accuracies.len() / 2]; - info!( - "Median Monthly Return: {:.2} | Median Accuracy: {:.2}%", - median_return * 100.0, - median_accuracy * 100.0 - ); - Ok((median_return, median_accuracy)) -} - -fn get_predictions(pls: PlsRegression, features: Vec>) -> Vec { - let features = Array2::from_shape_vec( - (features.len(), features[0].len()), - features.iter().flatten().cloned().collect(), - ) - .unwrap(); - pls.predict(&features).as_slice().unwrap().to_vec() -} - -#[instrument(skip(dataset))] -fn get_overall_dataset( - dataset: &IntervalData, - settings: AlgorithmSettings, -) -> (Vec>, Vec, Vec) { - let records = dataset.get_records(); - let predictors = normalize_by_columns(records); - // Set all NaN values to 0 - let predictors: Vec> = predictors - .iter() - .map(|r| { - r.iter() - .map(|v| { - if v.is_nan() { - debug!("Found NaN value"); - 0.0 - } else { - *v - } - }) - .collect() - }) - .collect(); - let features: Vec> = predictors - .windows(settings.depth) - .map(|w| w.iter().flatten().copied().collect::>()) - .collect(); - // Remove the last features row to match the labels length - let features: Vec> = 
features.iter().take(features.len() - 1).cloned().collect(); - let labels: Vec = dataset - .get(&settings.symbol) - .unwrap() - .get_labels() - .iter() - .skip(settings.depth) - .cloned() - .collect(); - // Set NaN values to 1 - let labels: Vec = labels - .iter() - .map(|v| if v.is_nan() { 1.0 } else { *v }) - .collect(); - let candles: Vec = dataset - .get(&settings.symbol) - .unwrap() - .get_candles() - .iter() - .skip(settings.depth) - .cloned() - .collect(); - debug!("Features Shape: {}x{}", features.len(), features[0].len()); - debug!("Labels Shape: {}", labels.len()); - debug!("Candles Shape: {}", candles.len()); - (features, labels, candles) -} - #[derive(Debug, Clone, PartialEq)] pub struct AlgorithmSettings { pub n: usize, @@ -204,8 +187,33 @@ impl fmt::Display for AlgorithmSettings { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "symbol: {} | Depth: {} | N Components: {}", + "Symbol: {} | Depth: {} | Components: {}", self.symbol, self.depth, self.n ) } } + +pub struct AlgorithmResult { + pub monthly_return: f64, + pub accuracy: f64, +} + +impl AlgorithmResult { + pub fn new(monthly_return: f64, accuracy: f64) -> Self { + Self { + monthly_return, + accuracy, + } + } +} + +impl fmt::Display for AlgorithmResult { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Median Monthly Return: {:.2}% | Median Accuracy: {:.2}%", + self.monthly_return * 100.0, + self.accuracy * 100.0 + ) + } +} diff --git a/src/algorithm/pls.rs b/src/algorithm/pls.rs index b02c4c5..4aa08cb 100644 --- a/src/algorithm/pls.rs +++ b/src/algorithm/pls.rs @@ -1,4 +1,4 @@ -use linfa::traits::Fit; +use linfa::traits::{Fit, Predict as _}; use linfa_pls::PlsRegression; use ndarray::Array2; @@ -22,3 +22,10 @@ pub fn get_pls( .map_err(|e| KryptoError::FitError(e.to_string()))?; Ok(pls) } + +pub fn predict(pls: &PlsRegression, features: &[Vec]) -> Vec { + let flat_features: Vec = features.iter().flatten().cloned().collect(); + let array_features = 
Array2::from_shape_vec((features.len(), features[0].len()), flat_features) + .expect("Failed to create feature array"); + pls.predict(&array_features).as_slice().unwrap().to_vec() +} diff --git a/src/algorithm/test_data.rs b/src/algorithm/test_data.rs index 6fd254f..195fad1 100644 --- a/src/algorithm/test_data.rs +++ b/src/algorithm/test_data.rs @@ -1,9 +1,12 @@ use std::fmt; use crate::{ - config::KryptoConfig, data::candlestick::Candlestick, util::date_utils::days_between_datetime, + config::KryptoConfig, data::candlestick::Candlestick, error::KryptoError, + util::date_utils::days_between, }; +const STARTING_CASH: f64 = 1000.0; + pub struct TestData { pub cash_history: Vec, pub accuracy: f64, @@ -11,49 +14,83 @@ pub struct TestData { } impl TestData { - pub fn new(predictions: Vec, candles: Vec, config: &KryptoConfig) -> Self { - let days = - days_between_datetime(candles[0].open_time, candles[candles.len() - 1].close_time); - let mut position = Position::None; - let mut cash = 1000.0; - let mut correct = 0; - let mut incorrect = 0; - let mut cash_history = vec![cash]; - for i in 0..predictions.len() { - let prediction = predictions[i].signum(); - let position_now = Position::from_f64(prediction, candles[i].close); - if position == Position::None { - position = position_now.clone(); - } - if position != position_now { - let return_now = position.get_return(candles[i].close); - cash += cash * return_now; - cash -= cash * config.fee.unwrap_or_default(); - position = position_now; - if return_now > 0.0 { - correct += 1; - } else { - incorrect += 1; + pub fn new( + predictions: Vec, + candles: Vec, + config: &KryptoConfig, + ) -> Result { + if candles.is_empty() || predictions.is_empty() { + return Err(KryptoError::EmptyCandlesAndPredictions); + } + + if candles.len() != predictions.len() { + return Err(KryptoError::UnequalCandlesAndPredictions); + } + + let fee = config.fee.unwrap_or(0.0); + let days = days_between( + candles.first().unwrap().open_time, + 
candles.last().unwrap().close_time, + ); + let mut position: Option = None; + let mut inner = InnerTestData::default(); + + for (prediction, candle) in predictions.iter().zip(candles.iter()) { + let prediction_sign = prediction.signum(); + + let new_position = match prediction_sign { + p if p > 0.0 => Some(Position::Long(candle.close)), + p if p < 0.0 => Some(Position::Short(candle.close)), + _ => None, + }; + + // Check if we need to close the existing position + if position.is_some() && position != new_position { + // Close the existing position + if let Some(ref pos) = position { + inner.close_position(pos, candle, fee); } - cash_history.push(cash); + + position = new_position; + } else if position.is_none() { + // Open a new position if we don't have one + position = new_position.clone(); } + + // No position change; continue holding or staying out } - let months = days as f64 / 30.0; - let accuracy = correct as f64 / (correct + incorrect) as f64; - let monthly_return = (cash / 1000.0).powf(1.0 / months) - 1.0; - Self { - cash_history, + + // Close any remaining open position at the end + if let Some(ref pos) = position { + inner.close_position(pos, candles.last().unwrap(), fee); + } + + let months = days as f64 / 30.44; + let total_trades = inner.correct + inner.incorrect; + let accuracy = if total_trades > 0 { + inner.correct as f64 / total_trades as f64 + } else { + 0.0 + }; + let monthly_return = if months > 0.0 { + (inner.cash / 1000.0).powf(1.0 / months) - 1.0 + } else { + 0.0 + }; + + Ok(Self { + cash_history: inner.cash_history, accuracy, monthly_return, - } + }) } } impl fmt::Display for TestData { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "Accuracy: {:.2} | Monthly Return: {:.2}%", + "Accuracy: {:.2}% | Monthly Return: {:.2}%", self.accuracy * 100.0, self.monthly_return * 100.0 ) @@ -64,66 +101,55 @@ impl fmt::Display for TestData { enum Position { Long(f64), 
Short(f64), - None, } impl Position { - fn get_return(&self, close: f64) -> f64 { - match self { - Position::Long(entry) => (close - entry) / entry, - Position::Short(entry) => (entry - close) / entry, - Position::None => 0.0, - } - } - - fn from_f64(value: f64, open_price: f64) -> Self { - if value > 0.0 { - Position::Long(open_price) - } else if value < 0.0 { - Position::Short(open_price) - } else { - Position::None + fn get_return(&self, close_price: f64) -> f64 { + match *self { + Position::Long(entry_price) => (close_price - entry_price) / entry_price, + Position::Short(entry_price) => (entry_price - close_price) / entry_price, } } } -impl fmt::Display for Position { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Position::Long(entry) => write!(f, "Long: ${}", entry), - Position::Short(entry) => write!(f, "Short: ${}", entry), - Position::None => write!(f, "None"), - } +impl PartialEq for Position { + fn eq(&self, other: &Self) -> bool { + matches!( + (self, other), + (Position::Long(_), Position::Long(_)) | (Position::Short(_), Position::Short(_)) + ) } } -impl PartialEq for f64 { - fn eq(&self, other: &Position) -> bool { - match other { - Position::Long(_) => self > &0.0, - Position::Short(_) => self < &0.0, - Position::None => self == &0.0, - } - } +struct InnerTestData { + cash: f64, + correct: u32, + incorrect: u32, + cash_history: Vec, } -impl PartialEq for Position { - fn eq(&self, other: &f64) -> bool { - match self { - Position::Long(_) => other > &0.0, - Position::Short(_) => other < &0.0, - Position::None => other == &0.0, +impl InnerTestData { + fn close_position(&mut self, position: &Position, candle: &Candlestick, fee: f64) { + let return_now = position.get_return(candle.close); + self.cash += self.cash * return_now; + self.cash -= self.cash * fee; + self.cash_history.push(self.cash); + + if return_now > 0.0 { + self.correct += 1; + } else { + self.incorrect += 1; } } } -impl PartialEq for Position { - fn eq(&self, 
other: &Position) -> bool { - // Simply compare the type of the enum - match self { - Position::Long(_) => matches!(other, Position::Long(_)), - Position::Short(_) => matches!(other, Position::Short(_)), - Position::None => matches!(other, Position::None), +impl Default for InnerTestData { + fn default() -> Self { + Self { + cash: STARTING_CASH, + correct: 0, + incorrect: 0, + cash_history: vec![STARTING_CASH], } } } diff --git a/src/data/dataset.rs b/src/data/dataset.rs index c3afa79..145a821 100644 --- a/src/data/dataset.rs +++ b/src/data/dataset.rs @@ -125,7 +125,7 @@ impl IntervalData { } records.push(record); } - records + records } } diff --git a/src/error.rs b/src/error.rs index c3b317a..1a4e4c4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,6 +22,10 @@ pub enum KryptoError { FitError(String), #[error("CSV Error: {0}")] CsvError(String), + #[error("Candles and predictions cannot be empty")] + EmptyCandlesAndPredictions, + #[error("Candles and predictions must be of the same length")] + UnequalCandlesAndPredictions, } #[derive(Debug, Clone, PartialEq)] diff --git a/src/main.rs b/src/main.rs index 1351abd..fbbcc8e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,15 +40,15 @@ fn run() -> Result<(), KryptoError> { info!("Interval: {}", interval); for symbol in interval_data.keys() { info!("Symbol: {}", symbol); - for n in 1..25 { - for depth in 1..25 { + for n in 1..50 { + for depth in 1..50 { if n >= depth * TECHNICAL_COUNT { continue; } let settings = AlgorithmSettings::new(n, depth, symbol); let algorithm = Algorithm::load(interval_data, settings, &config)?; - if algorithm.get_monthly_return() > &best_return { - best_return = *algorithm.get_monthly_return(); + if algorithm.get_monthly_return() > best_return { + best_return = algorithm.get_monthly_return(); best_algorithm = Some(i); info!("New best algorithm: {}", &algorithm); } diff --git a/src/util/date_utils.rs b/src/util/date_utils.rs index 5d2e0c5..08f80fc 100644 --- a/src/util/date_utils.rs +++ 
b/src/util/date_utils.rs @@ -34,10 +34,6 @@ pub fn get_timestamps( Ok(timestamps) } -pub fn days_between(start: NaiveDate, end: NaiveDate) -> i64 { - (end - start).num_days() -} - -pub fn days_between_datetime(start: DateTime, end: DateTime) -> i64 { +pub fn days_between(start: DateTime, end: DateTime) -> i64 { (end - start).num_days() } diff --git a/src/util/math_utils.rs b/src/util/math_utils.rs new file mode 100644 index 0000000..7e5559e --- /dev/null +++ b/src/util/math_utils.rs @@ -0,0 +1,10 @@ +pub fn median(values: &[f64]) -> f64 { + let mut sorted_values = values.to_vec(); + sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let mid = sorted_values.len() / 2; + if sorted_values.len() % 2 == 0 { + (sorted_values[mid - 1] + sorted_values[mid]) / 2.0 + } else { + sorted_values[mid] + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index bc7694c..50f259d 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,2 +1,3 @@ pub mod date_utils; pub mod matrix_utils; +pub mod math_utils; \ No newline at end of file