Skip to content

Commit

Permalink
Refactor Algorithm and TestData for better structure and error handling
Browse files Browse the repository at this point in the history
- Introduce `AlgorithmResult` struct to encapsulate monthly return and accuracy.
- Move `backtest` function into the `Algorithm` impl block and refactor its logic.
- Update `Algorithm` struct to use `result: AlgorithmResult` instead of separate fields.
- Add error handling in `TestData::new`, returning `Result<Self, KryptoError>`, and check for empty or unequal lengths of candles and predictions.
- Refactor position handling logic in `TestData` for clarity and correctness.
- Add new error variants `EmptyCandlesAndPredictions` and `UnequalCandlesAndPredictions` in `KryptoError`.
- Add `predict` function in `src/algorithm/pls.rs` for PLS predictions.
- Adjust `days_between` function in `date_utils` to accept `DateTime<Utc>` parameters.
- Update `main.rs` and other affected files to accommodate these changes.
- Upgrade GitHub Actions `upload-artifact` from `v2` to `v4` in `rust.yml`.

This refactoring improves code organization, error handling, and overall code readability.
  • Loading branch information
noahbclarkson committed Nov 18, 2024
1 parent 793a1c6 commit b8480e7
Show file tree
Hide file tree
Showing 10 changed files with 271 additions and 219 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:

- name: Upload Artifact (Windows)
if: runner.os == 'Windows'
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: krypto-${{ runner.os }}
path: artifacts/krypto.exe
Expand Down
268 changes: 138 additions & 130 deletions src/algorithm/algo.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
use std::fmt;

use linfa::traits::Predict as _;
use linfa_pls::PlsRegression;
use ndarray::Array2;
use tracing::{debug, info, instrument};

use crate::{
algorithm::{pls::get_pls, test_data::TestData},
algorithm::{
pls::{get_pls, predict},
test_data::TestData,
},
config::KryptoConfig,
data::{candlestick::Candlestick, dataset::IntervalData},
error::KryptoError,
util::matrix_utils::normalize_by_columns,
util::{math_utils::median, matrix_utils::normalize_by_columns},
};

pub struct Algorithm {
pub pls: PlsRegression<f64>,
settings: AlgorithmSettings,
monthly_return: f64,
accuracy: f64,
result: AlgorithmResult,
}

impl Algorithm {
Expand All @@ -27,27 +27,124 @@ impl Algorithm {
settings: AlgorithmSettings,
config: &KryptoConfig,
) -> Result<Self, KryptoError> {
let (monthly_return, accuracy) = backtest(dataset, settings.clone(), config)?;
let (features, labels, _) = get_overall_dataset(dataset, settings.clone());
let result = Self::backtest(dataset, &settings, config)?;
let (features, labels, _) = Self::prepare_dataset(dataset, &settings);
let pls = get_pls(features, labels, settings.n)?;
Ok(Self {
pls,
settings,
monthly_return,
accuracy,
result,
})
}

/// Run a k-fold cross-validated backtest for the given settings.
///
/// The series is split into `config.cross_validations` contiguous folds;
/// each fold is held out as test data while a PLS model is fitted on the
/// remaining folds. Returns the median monthly return and median accuracy
/// across folds.
///
/// # Errors
/// Propagates `KryptoError` from PLS fitting or `TestData::new`.
fn backtest(
    dataset: &IntervalData,
    settings: &AlgorithmSettings,
    config: &KryptoConfig,
) -> Result<AlgorithmResult, KryptoError> {
    debug!("Running backtest");

    let (features, labels, candles) = Self::prepare_dataset(dataset, settings);
    let count = config.cross_validations;
    // NOTE(review): assumes count >= 1 (count == 0 would divide by zero) —
    // presumably validated when the config is loaded; confirm.
    let total_size = candles.len();
    let test_data_size = total_size / count;

    let mut test_results = Vec::with_capacity(count);

    for i in 0..count {
        let start = i * test_data_size;
        // The last fold absorbs the division remainder so every candle is
        // tested exactly once. (if/else over `match` on a bool — clippy::match_bool)
        let end = if i == count - 1 {
            total_size
        } else {
            (i + 1) * test_data_size
        };

        let test_features = &features[start..end];
        let test_candles = &candles[start..end];
        // Train on everything outside the held-out [start, end) window.
        let train_features = [&features[..start], &features[end..]].concat();
        let train_labels = [&labels[..start], &labels[end..]].concat();

        let pls = get_pls(train_features, train_labels, settings.n)?;
        let predictions = predict(&pls, test_features);

        let test_data = TestData::new(predictions, test_candles.to_vec(), config)?;
        debug!(
            "Cross-validation {} ({}-{}): {}",
            i + 1,
            start,
            end,
            test_data
        );
        test_results.push(test_data);
    }

    // Median across folds is robust to a single pathological split.
    let median_return = median(
        &test_results
            .iter()
            .map(|d| d.monthly_return)
            .collect::<Vec<_>>(),
    );
    let median_accuracy = median(&test_results.iter().map(|d| d.accuracy).collect::<Vec<_>>());
    let result = AlgorithmResult::new(median_return, median_accuracy);
    info!("Backtest result: {}", result);
    Ok(result)
}

/// Build the aligned (features, labels, candles) triple used for fitting
/// and backtesting.
///
/// Predictor records are normalized column-wise (NaNs mapped to 0.0, e.g.
/// from zero-variance columns), flattened into sliding windows of
/// `settings.depth` rows, and aligned with the target symbol's labels and
/// candles, which are shifted forward by `depth`.
///
/// # Panics
/// Panics if `settings.symbol` is not present in `dataset`.
#[instrument(skip(dataset))]
fn prepare_dataset(
    dataset: &IntervalData,
    settings: &AlgorithmSettings,
) -> (Vec<Vec<f64>>, Vec<f64>, Vec<Candlestick>) {
    let records = dataset.get_records();
    let normalized_predictors = normalize_by_columns(records)
        .into_iter()
        .map(|row| {
            row.into_iter()
                .map(|v| if v.is_nan() { 0.0 } else { v })
                .collect()
        })
        .collect::<Vec<Vec<f64>>>();

    // Each feature row is a flattened window of `depth` consecutive records.
    let mut features = normalized_predictors
        .windows(settings.depth)
        .map(|window| window.iter().flatten().cloned().collect())
        .collect::<Vec<Vec<f64>>>();
    // Drop the final window: it has no forward label to pair with.
    // (`pop` is a no-op on empty input, unlike slicing to `len() - 1`,
    // which underflows when no windows were produced.)
    features.pop();

    let symbol_data = dataset
        .get(&settings.symbol)
        .expect("Symbol not found in dataset");

    // NaN labels default to 1.0 — assumed to mean "no change"; TODO confirm
    // against the label semantics in the dataset module.
    let labels: Vec<f64> = symbol_data
        .get_labels()
        .iter()
        .skip(settings.depth)
        .map(|&v| if v.is_nan() { 1.0 } else { v })
        .collect();

    let candles: Vec<Candlestick> = symbol_data
        .get_candles()
        .iter()
        .skip(settings.depth)
        .cloned()
        .collect();

    // Avoid indexing `features[0]`: it panics when the dataset is shorter
    // than `depth` and no windows exist.
    debug!(
        "Features shape: {}x{}",
        features.len(),
        features.first().map_or(0, |row| row.len())
    );
    debug!("Labels count: {}", labels.len());
    debug!("Candles count: {}", candles.len());

    (features, labels, candles)
}

/// Symbol this algorithm was configured for.
pub fn get_symbol(&self) -> &str {
    self.settings.symbol.as_str()
}

pub fn get_monthly_return(&self) -> &f64 {
&self.monthly_return
pub fn get_monthly_return(&self) -> f64 {
self.result.monthly_return
}

pub fn get_accuracy(&self) -> f64 {
self.accuracy
self.result.accuracy
}

pub fn get_n_components(&self) -> usize {
Expand All @@ -63,126 +160,12 @@ impl fmt::Display for Algorithm {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Algorithm: ({}) | Monthly Return: {:.2} | Accuracy: {:.2}%",
self.settings,
self.monthly_return * 100.0,
self.accuracy * 100.0
"Algorithm: ({}) | Result: ({})",
self.settings, self.result
)
}
}

/// Cross-validated backtest returning `(median monthly return, median accuracy)`.
///
/// # Errors
/// Propagates `KryptoError` from PLS fitting.
#[instrument(skip(dataset, config, settings))]
fn backtest(
    dataset: &IntervalData,
    settings: AlgorithmSettings,
    config: &KryptoConfig,
) -> Result<(f64, f64), KryptoError> {
    info!("Running backtest");
    let (features, labels, candles) = get_overall_dataset(dataset, settings.clone());
    let count = config.cross_validations;
    let total_size = candles.len();
    let test_data_size = (total_size as f64 / count as f64).floor() as usize - 1;
    let mut test_datas = Vec::new();
    for i in 0..count {
        let start = i * test_data_size;
        // Last fold absorbs the remainder so every candle is tested once.
        let end = if i == count - 1 {
            total_size
        } else {
            (i + 1) * test_data_size
        };
        debug!("Start: {} | End: {}", start, end);
        // Hold out [start, end) as the test fold; train on the rest.
        let mut train_features = features.clone();
        let test_features: Vec<Vec<f64>> = train_features.drain(start..end).collect();
        let mut train_labels = labels.clone();
        train_labels.drain(start..end);
        let test_candles = candles.clone().drain(start..end).collect();
        let pls = get_pls(train_features, train_labels, settings.n)?;
        let predictions = get_predictions(pls, test_features);
        debug!("Running cross validation: {}/{}", i + 1, count);
        let test_data = TestData::new(predictions, test_candles, config);
        debug!("Cross Validation {}: {}", i + 1, test_data);
        test_datas.push(test_data);
    }
    // BUG FIX: the middle element of an UNSORTED vec is not the median —
    // sort first, then index len()/2.
    let mut returns = test_datas
        .iter()
        .map(|d| d.monthly_return)
        .collect::<Vec<f64>>();
    let mut accuracies = test_datas.iter().map(|d| d.accuracy).collect::<Vec<f64>>();
    returns.sort_by(|a, b| a.total_cmp(b));
    accuracies.sort_by(|a, b| a.total_cmp(b));
    let median_return = returns[returns.len() / 2];
    let median_accuracy = accuracies[accuracies.len() / 2];
    info!(
        "Median Monthly Return: {:.2} | Median Accuracy: {:.2}%",
        median_return * 100.0,
        median_accuracy * 100.0
    );
    Ok((median_return, median_accuracy))
}

/// Flatten `features` (rows of equal length) into an `Array2` and run the
/// fitted PLS model, returning one predicted value per input row.
fn get_predictions(pls: PlsRegression<f64>, features: Vec<Vec<f64>>) -> Vec<f64> {
    // Guard: `features[0]` below would panic on an empty input.
    if features.is_empty() {
        return Vec::new();
    }
    let features = Array2::from_shape_vec(
        (features.len(), features[0].len()),
        features.iter().flatten().cloned().collect(),
    )
    .expect("feature rows must all have the same length");
    pls.predict(&features)
        .as_slice()
        .expect("prediction output should be contiguous")
        .to_vec()
}

/// Build the aligned (features, labels, candles) triple: column-normalized
/// records flattened into windows of `settings.depth`, paired with the
/// target symbol's labels and candles shifted forward by `depth`.
///
/// # Panics
/// Panics if `settings.symbol` is not present in `dataset`.
#[instrument(skip(dataset))]
fn get_overall_dataset(
    dataset: &IntervalData,
    settings: AlgorithmSettings,
) -> (Vec<Vec<f64>>, Vec<f64>, Vec<Candlestick>) {
    let records = dataset.get_records();
    let predictors = normalize_by_columns(records);
    // Set all NaN values to 0 (e.g. produced by zero-variance columns).
    let predictors: Vec<Vec<f64>> = predictors
        .iter()
        .map(|r| {
            r.iter()
                .map(|v| {
                    if v.is_nan() {
                        debug!("Found NaN value");
                        0.0
                    } else {
                        *v
                    }
                })
                .collect()
        })
        .collect();
    let mut features: Vec<Vec<f64>> = predictors
        .windows(settings.depth)
        .map(|w| w.iter().flatten().copied().collect::<Vec<f64>>())
        .collect();
    // Remove the last features row to match the labels length.
    // (`pop` is a no-op on empty input, unlike `take(len - 1)`, which
    // underflows when no windows were produced.)
    features.pop();
    // Look the symbol up once instead of twice.
    let symbol_data = dataset
        .get(&settings.symbol)
        .expect("Symbol not found in dataset");
    // NaN labels default to 1.0 — assumed "no change"; TODO confirm semantics.
    let labels: Vec<f64> = symbol_data
        .get_labels()
        .iter()
        .skip(settings.depth)
        .map(|v| if v.is_nan() { 1.0 } else { *v })
        .collect();
    let candles: Vec<Candlestick> = symbol_data
        .get_candles()
        .iter()
        .skip(settings.depth)
        .cloned()
        .collect();
    // Avoid `features[0]`: it panics when no windows exist.
    debug!(
        "Features Shape: {}x{}",
        features.len(),
        features.first().map_or(0, |row| row.len())
    );
    debug!("Labels Shape: {}", labels.len());
    debug!("Candles Shape: {}", candles.len());
    (features, labels, candles)
}

#[derive(Debug, Clone, PartialEq)]
pub struct AlgorithmSettings {
pub n: usize,
Expand All @@ -204,8 +187,33 @@ impl fmt::Display for AlgorithmSettings {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"symbol: {} | Depth: {} | N Components: {}",
"Symbol: {} | Depth: {} | Components: {}",
self.symbol, self.depth, self.n
)
}
}

/// Outcome of a backtest: the median monthly return and median accuracy
/// across cross-validation folds.
///
/// Both values are stored as fractions; the `Display` impl scales them by
/// 100 for percentage output.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AlgorithmResult {
    // Median monthly return (fraction, e.g. 0.05 for 5%/month).
    pub monthly_return: f64,
    // Median prediction accuracy (fraction).
    pub accuracy: f64,
}

impl AlgorithmResult {
    /// Create a result from the two median metrics.
    pub fn new(monthly_return: f64, accuracy: f64) -> Self {
        Self {
            monthly_return,
            accuracy,
        }
    }
}

/// Renders e.g. `Median Monthly Return: 5.00% | Median Accuracy: 75.00%`.
impl fmt::Display for AlgorithmResult {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Stored as fractions; scale to percentages for display.
        let return_pct = self.monthly_return * 100.0;
        let accuracy_pct = self.accuracy * 100.0;
        write!(
            f,
            "Median Monthly Return: {:.2}% | Median Accuracy: {:.2}%",
            return_pct, accuracy_pct
        )
    }
}
9 changes: 8 additions & 1 deletion src/algorithm/pls.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use linfa::traits::Fit;
use linfa::traits::{Fit, Predict as _};
use linfa_pls::PlsRegression;
use ndarray::Array2;

Expand All @@ -22,3 +22,10 @@ pub fn get_pls(
.map_err(|e| KryptoError::FitError(e.to_string()))?;
Ok(pls)
}

/// Run the fitted PLS model over `features` (rows of equal length),
/// returning one predicted value per row.
pub fn predict(pls: &PlsRegression<f64>, features: &[Vec<f64>]) -> Vec<f64> {
    // Guard: `features[0]` below would panic on an empty slice;
    // no rows means no predictions.
    if features.is_empty() {
        return Vec::new();
    }
    let flat_features: Vec<f64> = features.iter().flatten().cloned().collect();
    let array_features = Array2::from_shape_vec((features.len(), features[0].len()), flat_features)
        .expect("Failed to create feature array");
    pls.predict(&array_features)
        .as_slice()
        .expect("prediction output should be contiguous")
        .to_vec()
}
Loading

0 comments on commit b8480e7

Please sign in to comment.