Skip to content

Commit

Permalink
Merge pull request #26 from alexandrainst/send-futures
Browse files Browse the repository at this point in the history
Make the async traits `Send`
  • Loading branch information
quackzar authored Jun 20, 2024
2 parents af302b5 + 604bf04 commit 9094a1d
Show file tree
Hide file tree
Showing 16 changed files with 117 additions and 114 deletions.
4 changes: 2 additions & 2 deletions pycare/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 17 additions & 16 deletions pycare/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,48 @@
use pyo3::{prelude::*, types::PyTuple, exceptions::PyIOError};
use pyo3::{exceptions::PyIOError, prelude::*, types::PyTuple};

use std::fs::File;
use wecare::*;
use std::path::Path;

#[pyclass]
struct Engine(Option<AdderEngine>);
struct Engine(Option<SpdzEngine>);

/// Setup a MPC addition engine connected to the given sockets.
#[pyfunction]
#[pyo3(signature = (path_to_pre, my_addr, *others))]

fn setup(path_to_pre: &str, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult<Engine> {
let others : Vec<_> = others.iter().map(|x| x.extract::<String>().unwrap().clone())
let others: Vec<_> = others
.iter()
.map(|x| x.extract::<String>().unwrap().clone())
.collect();
let file_name = Path::new(path_to_pre);
match setup_engine(my_addr, &others, file_name) {
let mut file = File::open(path_to_pre).unwrap();
match AdderEngine::spdz(my_addr, &others, &mut file) {
Ok(e) => Ok(Engine(Some(e))),
Err(e) => Err(PyIOError::new_err(e.0))
Err(e) => Err(PyIOError::new_err(e.0)),
}
}

/// Calculate and save the preprocessing
#[pyfunction]
#[pyo3(signature = (number_of_shares, paths_to_pre))]
fn preproc( number_of_shares: usize, paths_to_pre: &str){
let paths_to_pre = paths_to_pre.split(",").collect();
do_preproc(paths_to_pre, vec![number_of_shares, number_of_shares]);
fn preproc(number_of_shares: usize, paths_to_pre: &str) {
let mut files: Vec<File> = paths_to_pre
.split(",")
.map(|p| File::open(p).unwrap())
.collect();
do_preproc(&mut files, vec![number_of_shares, number_of_shares]);
}


#[pymethods]
impl Engine {

/// Run a sum procedure in which each party supplies a double floating point
fn sum(&mut self, a: f64) -> f64 {
mpc_sum(self.0.as_mut().unwrap(), &[a]).unwrap()[0]
self.0.as_mut().unwrap().mpc_sum(&[a]).unwrap()[0]
}

/// Run a sum procedure in which each party supplies a double floating point
fn sum_many(&mut self, a: Vec<f64>) -> Vec<f64> {
mpc_sum(self.0.as_mut().unwrap(), &a).unwrap()
self.0.as_mut().unwrap().mpc_sum(&a).unwrap()
}

/// takedown engine
Expand All @@ -49,7 +52,6 @@ impl Engine {
}
}


/// A Python module implemented in Rust.
#[pymodule]
fn caring(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand All @@ -58,4 +60,3 @@ fn caring(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Engine>()?;
Ok(())
}

2 changes: 1 addition & 1 deletion src/algebra/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<F> Vector<F> {
Self(Box::new(v))
}

pub fn len(&self) -> usize {
pub fn size(&self) -> usize {
self.0.len()
}

Expand Down
4 changes: 2 additions & 2 deletions src/algebra/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl<G: Field> Polynomial<G> {
impl<G: Send + Sync> Polynomial<G> {
pub fn degree(&self) -> usize {
// a0 + a1x1 is degree(1)
self.0.len() - 1
self.0.size() - 1
}
}

Expand Down Expand Up @@ -101,7 +101,7 @@ impl<
{
pub fn mult(&self, other: &Self) -> Polynomial<F> {
// degree is length - 1.
let n = self.0.len() + other.0.len();
let n = self.0.size() + other.0.size();
let iter = self
.0
.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/algebra/rayon.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
use std::ops::{Add, AddAssign, Div, DivAssign, Sub, SubAssign};

use rayon::prelude::*;

Expand Down
6 changes: 4 additions & 2 deletions src/marker/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Experimental module to try 'marking' different kind of shares in a program using types.
mod exptree;

use rand::RngCore;
Expand Down Expand Up @@ -41,7 +43,7 @@ impl<S> Unverified<S> {
}
}

impl<'ctx, S: InteractiveShared<'ctx>> Verified<S> {
impl<S: InteractiveShared> Verified<S> {
pub async fn open(
self,
ctx: &mut S::Context,
Expand All @@ -61,7 +63,7 @@ impl<'ctx, S: InteractiveShared<'ctx>> Verified<S> {
}
}

impl<'ctx, S: InteractiveShared<'ctx>> Unverified<S> {
impl<S: InteractiveShared> Unverified<S> {
pub async fn share_symmetric(
val: S::Value,
ctx: &mut S::Context,
Expand Down
50 changes: 25 additions & 25 deletions src/net/agency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use std::{error::Error, marker::PhantomData};
use futures::Future;
use itertools::Itertools;

pub trait Broadcast {
pub trait Broadcast: Send {
type BroadcastError: Error + Send + 'static;
// type Error: Error + 'static;

Expand All @@ -38,7 +38,7 @@ pub trait Broadcast {
fn broadcast(
&mut self,
msg: &(impl serde::Serialize + Sync),
) -> impl std::future::Future<Output = Result<(), Self::BroadcastError>>;
) -> impl std::future::Future<Output = Result<(), Self::BroadcastError>> + Send;

/// Broadcast a message to all parties and await their messages
/// Messages are ordered by their index.
Expand All @@ -50,17 +50,17 @@ pub trait Broadcast {
fn symmetric_broadcast<T>(
&mut self,
msg: T,
) -> impl Future<Output = Result<Vec<T>, Self::BroadcastError>>
) -> impl Future<Output = Result<Vec<T>, Self::BroadcastError>> + Send
where
T: serde::Serialize + serde::de::DeserializeOwned + Sync;
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync;

/// Receive a message from a party
///
/// Returns: a message from the given party or an error
fn recv_from<T: serde::de::DeserializeOwned>(
fn recv_from<T: serde::de::DeserializeOwned + Send>(
&mut self,
idx: usize,
) -> impl Future<Output = Result<T, Self::BroadcastError>>;
) -> impl Future<Output = Result<T, Self::BroadcastError>> + Send;

/// Size of the broadcasting network including yourself,
/// as such there is n-1 outgoing connections
Expand All @@ -73,24 +73,24 @@ impl<'a, B: Broadcast> Broadcast for &'a mut B {
fn broadcast(
&mut self,
msg: &(impl serde::Serialize + Sync),
) -> impl std::future::Future<Output = Result<(), Self::BroadcastError>> {
) -> impl std::future::Future<Output = Result<(), Self::BroadcastError>> + Send {
(**self).broadcast(msg)
}

fn symmetric_broadcast<T>(
&mut self,
msg: T,
) -> impl Future<Output = Result<Vec<T>, Self::BroadcastError>>
) -> impl Future<Output = Result<Vec<T>, Self::BroadcastError>> + Send
where
T: serde::Serialize + serde::de::DeserializeOwned + Sync,
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync,
{
(**self).symmetric_broadcast(msg)
}

fn recv_from<T: serde::de::DeserializeOwned>(
fn recv_from<T: serde::de::DeserializeOwned + Send>(
&mut self,
idx: usize,
) -> impl Future<Output = Result<T, Self::BroadcastError>> {
) -> impl Future<Output = Result<T, Self::BroadcastError>> + Send {
(**self).recv_from(idx)
}

Expand All @@ -113,8 +113,8 @@ pub trait Unicast {
/// * `msgs`: Messages to send
fn unicast(
&mut self,
msgs: &[impl serde::Serialize + Sync],
) -> impl std::future::Future<Output = Result<(), Self::UnicastError>>;
msgs: &[impl serde::Serialize + Send + Sync],
) -> impl std::future::Future<Output = Result<(), Self::UnicastError>> + Send;

/// Unicast a message to each party and await their messages
/// Messages are supposed to be in order, meaning message `i`
Expand All @@ -124,18 +124,18 @@ pub trait Unicast {
fn symmetric_unicast<T>(
&mut self,
msgs: Vec<T>,
) -> impl Future<Output = Result<Vec<T>, Self::UnicastError>>
) -> impl Future<Output = Result<Vec<T>, Self::UnicastError>> + Send
where
T: serde::Serialize + serde::de::DeserializeOwned + Sync;
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync;

/// Receive a message for each party.
///
/// Asymmetric, waiting
///
/// Returns: A list sorted by the connections (skipping yourself)
fn receive_all<T: serde::de::DeserializeOwned>(
fn receive_all<T: serde::de::DeserializeOwned + Send>(
&mut self,
) -> impl Future<Output = Result<Vec<T>, Self::UnicastError>>;
) -> impl Future<Output = Result<Vec<T>, Self::UnicastError>> + Send;

/// Size of the unicasting network including yourself,
/// as such there is n-1 outgoing connections
Expand All @@ -149,25 +149,25 @@ impl<'a, U: Unicast> Unicast for &'a mut U {
(**self).size()
}

fn receive_all<T: serde::de::DeserializeOwned>(
fn receive_all<T: serde::de::DeserializeOwned + Send>(
&mut self,
) -> impl Future<Output = Result<Vec<T>, Self::UnicastError>> {
) -> impl Future<Output = Result<Vec<T>, Self::UnicastError>> + Send {
(**self).receive_all()
}

fn unicast(
&mut self,
msgs: &[impl serde::Serialize + Sync],
) -> impl std::future::Future<Output = Result<(), Self::UnicastError>> {
msgs: &[impl serde::Serialize + Send + Sync],
) -> impl std::future::Future<Output = Result<(), Self::UnicastError>> + Send {
(**self).unicast(msgs)
}

fn symmetric_unicast<T>(
&mut self,
msgs: Vec<T>,
) -> impl Future<Output = Result<Vec<T>, Self::UnicastError>>
) -> impl Future<Output = Result<Vec<T>, Self::UnicastError>> + Send
where
T: serde::Serialize + serde::de::DeserializeOwned + Sync,
T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync,
{
(**self).symmetric_unicast(msgs)
}
Expand Down Expand Up @@ -353,7 +353,7 @@ pub enum BroadcastVerificationError<E> {
Other(E),
}

impl<B: Broadcast, D: Digest> Broadcast for VerifiedBroadcast<B, D> {
impl<B: Broadcast, D: Digest + Send> Broadcast for VerifiedBroadcast<B, D> {
type BroadcastError = BroadcastVerificationError<<B as Broadcast>::BroadcastError>;

async fn broadcast(&mut self, msg: &impl serde::Serialize) -> Result<(), Self::BroadcastError> {
Expand All @@ -370,7 +370,7 @@ impl<B: Broadcast, D: Digest> Broadcast for VerifiedBroadcast<B, D> {
fn recv_from<T: serde::de::DeserializeOwned>(
&mut self,
idx: usize,
) -> impl Future<Output = Result<T, Self::BroadcastError>> {
) -> impl Future<Output = Result<T, Self::BroadcastError>> + Send {
self.recv_from(idx)
}

Expand Down
8 changes: 4 additions & 4 deletions src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ pub trait Tuneable {
fn recv_from<T: serde::de::DeserializeOwned>(
&mut self,
idx: usize,
) -> impl Future<Output = Result<T, Self::TuningError>>;
) -> impl Future<Output = Result<T, Self::TuningError>> + Send;

fn send_to<T: serde::Serialize + Sync>(
&mut self,
idx: usize,
msg: &T,
) -> impl Future<Output = Result<(), Self::TuningError>>;
) -> impl Future<Output = Result<(), Self::TuningError>> + Send;
}

impl<'a, R: Tuneable + ?Sized> Tuneable for &'a mut R {
Expand All @@ -118,15 +118,15 @@ impl<'a, R: Tuneable + ?Sized> Tuneable for &'a mut R {
fn recv_from<T: serde::de::DeserializeOwned>(
&mut self,
idx: usize,
) -> impl Future<Output = Result<T, Self::TuningError>> {
) -> impl Future<Output = Result<T, Self::TuningError>> + Send {
(**self).recv_from(idx)
}

fn send_to<T: serde::Serialize + Sync>(
&mut self,
idx: usize,
msg: &T,
) -> impl Future<Output = Result<(), Self::TuningError>> {
) -> impl Future<Output = Result<(), Self::TuningError>> + Send {
(**self).send_to(idx, msg)
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/net/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ impl<C: SplitChannel> Unicast for Network<C> {
}

#[tracing::instrument(skip_all)]
async fn receive_all<T: serde::de::DeserializeOwned>(
async fn receive_all<T: serde::de::DeserializeOwned + Send>(
&mut self,
) -> Result<Vec<T>, Self::UnicastError> {
self.receive_all().await
Expand Down Expand Up @@ -399,7 +399,7 @@ impl<C: SplitChannel> Broadcast for Network<C> {
fn recv_from<T: serde::de::DeserializeOwned>(
&mut self,
idx: usize,
) -> impl Future<Output = Result<T, Self::BroadcastError>> {
) -> impl Future<Output = Result<T, Self::BroadcastError>> + Send {
Tuneable::recv_from(self, idx)
}

Expand Down
4 changes: 2 additions & 2 deletions src/ot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub trait ObliviousSend<C: Channel> {
pkg0: &T,
pkg1: &T,
channel: &mut C,
) -> impl Future<Output = Result<(), Self::Error>>;
) -> impl Future<Output = Result<(), Self::Error>> + Send;
}

pub trait ObliviousReceive<C: Channel> {
Expand All @@ -49,7 +49,7 @@ pub trait ObliviousReceive<C: Channel> {
fn choose<T: serde::de::DeserializeOwned>(
choice: bool,
channel: &mut C,
) -> impl Future<Output = Result<T, Self::Error>>;
) -> impl Future<Output = Result<T, Self::Error>> + Send;
}

/// A Mock OT that provides no security what-so-ever.
Expand Down
Loading

0 comments on commit 9094a1d

Please sign in to comment.