Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/candle-linear-regression'
Browse files Browse the repository at this point in the history
  • Loading branch information
vkomenda committed Aug 11, 2024
2 parents 9b1ba75 + 4b95be7 commit b708525
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 58 deletions.
69 changes: 35 additions & 34 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 44 additions & 13 deletions ml/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use forust_ml::{GradientBooster, Matrix};
use candle_core::{Device, Result, Tensor};
use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW};
use std::error::Error;

#[no_mangle]
Expand All @@ -19,7 +20,7 @@ pub extern "C" fn run() {
}
}

fn greet_internal() -> Result<Vec<f64>, Box<dyn Error>> {
fn greet_internal() -> Result<Vec<f32>, Box<dyn Error>> {
// let csv_data = r#"survived,pclass,sex,age,sibsp,parch,fare,embarked,class,who,adult_male,deck,embark_town,alive,alone
// 0,3,male,22.0,1,0,7.25,S,Third,man,True,,Southampton,no,False
// 1,1,female,38.0,1,0,71.2833,C,First,woman,False,C,Cherbourg,yes,False
Expand All @@ -34,30 +35,60 @@ fn greet_internal() -> Result<Vec<f64>, Box<dyn Error>> {
-122.46,37.75,26.0,2192.0,438.0,954.0,456.0,4.5352,374200.0,NEAR BAY
-119.24,36.33,9.0,3289.0,621.0,1866.0,631.0,3.1599,95000.0,INLAND"#;

let mut data = Vec::new();
let mut y = Vec::new();
let mut features = Vec::new();
let mut target = Vec::new();

for line in csv_data.lines() {
let values: Vec<&str> = line.split(',').collect();

// Assuming the CSV structure: survived,pclass,age,sibsp,parch,fare
y.push(values[8].parse::<f64>().unwrap_or(f64::NAN));
target.push(values[8].parse::<f32>().unwrap_or(f32::NAN));

data.extend_from_slice(
features.extend_from_slice(
&(0..8)
.map(|col| values[col].parse::<f64>().unwrap_or(f64::NAN))
.map(|col| values[col].parse::<f32>().unwrap_or(f32::NAN))
.collect::<Vec<_>>(),
);
}

println!("{data:?}");
println!("{y:?}");
println!("{features:?}");
println!("{target:?}");

let matrix = Matrix::new(&data, csv_data.lines().count(), 8);
let num_samples = features.len();
let num_features = 8;

let mut model = GradientBooster::default().set_learning_rate(0.3);
model.fit_unweighted(&matrix, &y, None)?;
// Convert data to tensors
let features_tensor =
Tensor::from_slice(&features, &[num_samples, num_features], &Device::Cpu)?;
let target_tensor = Tensor::from_slice(&target, &[num_samples, 1], &Device::Cpu)?;

// Step 6: Define the linear regression model
let mut model = Linear::new(num_features, 1);

// Step 7: Set up the optimizer
let mut optimizer = AdamW::new(
vec![],
ParamsAdamW::default(), // OptimizerConfig::adam(0.01).build(model.parameters()),
);

// Step 8: Training Loop
let num_epochs = 100;
for epoch in 0..num_epochs {
let predictions = model.forward(&features_tensor)?;
let loss = mse_loss(&predictions, &target_tensor)?;
optimizer.step(&loss)?;

if epoch % 10 == 0 {
println!("Epoch {}: Loss = {:?}", epoch, loss);
}
}

// Step 9: Make and print predictions
let predictions = model.forward(&features_tensor)?;
println!(
"Model predictions (first 10): {:?}",
predictions.slice([..10])
);

let predictions = model.predict(&matrix, true);
Ok(predictions)
}
3 changes: 2 additions & 1 deletion zk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"

[dependencies]
anyhow = "1.0.86"
ff = "0.13"
rayon = "1.10.0"
tracing = "0.1.40"
zk-engine = { git = "https://github.com/ICME-Lab/zkEngine_dev.git" }
zk-engine = { git = "https://github.com/ICME-Lab/zkEngine_dev.git", branch = "main" }
Loading

0 comments on commit b708525

Please sign in to comment.