Skip to content

Change accumulator releated signatures #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 135 additions & 54 deletions hyperactor/src/accum.rs

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion hyperactor/src/cap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub(crate) mod sealed {
use async_trait::async_trait;

use crate::PortId;
use crate::accum::ReducerSpec;
use crate::actor::Actor;
use crate::actor::ActorHandle;
use crate::data::Serialized;
Expand All @@ -47,7 +48,11 @@ pub(crate) mod sealed {
}

pub trait CanSplitPort: Send + Sync {
fn split(&self, port_id: PortId, reducer: Option<u64>) -> PortId;
fn split(
&self,
port_id: PortId,
reducer_spec: Option<ReducerSpec>,
) -> anyhow::Result<PortId>;
}

#[async_trait]
Expand Down
57 changes: 33 additions & 24 deletions hyperactor/src/mailbox/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ use crate::OncePortRef;
use crate::PortRef;
use crate::accum;
use crate::accum::Accumulator;
use crate::accum::ReducerSpec;
use crate::actor::Signal;
use crate::actor::remote::USER_PORT_OFFSET;
use crate::cap;
Expand Down Expand Up @@ -1074,10 +1075,10 @@ impl Mailbox {
let (sender, receiver) = mpsc::unbounded_channel::<A::State>();
let port_id = PortId(self.state.actor_id.clone(), port_index);
let state = Mutex::new(A::State::default());
let reducer_typehash = accum.reducer_typehash();
let reducer_spec = accum.reducer_spec();
let enqueue = move |update: A::Update| {
let mut state = state.lock().unwrap();
accum.accumulate(&mut state, update);
accum.accumulate(&mut state, update)?;
let _ = sender.send(state.clone());
Ok(())
};
Expand All @@ -1087,7 +1088,7 @@ impl Mailbox {
port_index,
sender: UnboundedPortSender::Func(Arc::new(enqueue)),
bound: Arc::new(OnceLock::new()),
reducer_typehash: Some(reducer_typehash),
reducer_spec,
},
PortReceiver::new(
receiver,
Expand All @@ -1110,7 +1111,7 @@ impl Mailbox {
port_index: self.state.allocate_port(),
sender: UnboundedPortSender::Func(Arc::new(enqueue)),
bound: Arc::new(OnceLock::new()),
reducer_typehash: None,
reducer_spec: None,
}
}

Expand Down Expand Up @@ -1298,7 +1299,7 @@ impl SplitPortBuffer {
}

impl cap::sealed::CanSplitPort for Mailbox {
fn split(&self, port_id: PortId, reducer_typehash: Option<u64>) -> PortId {
fn split(&self, port_id: PortId, reducer_spec: Option<ReducerSpec>) -> anyhow::Result<PortId> {
fn post(mailbox: &Mailbox, port_id: PortId, msg: Serialized) {
mailbox.post(
MessageEnvelope::new(mailbox.actor_id().clone(), port_id, msg),
Expand All @@ -1313,7 +1314,15 @@ impl cap::sealed::CanSplitPort for Mailbox {
let port_index = self.state.allocate_port();
let split_port = self.actor_id().port_id(port_index);
let mailbox = self.clone();
let reducer = reducer_typehash.and_then(accum::resolve_reducer);
let reducer = reducer_spec
.map(
|ReducerSpec {
typehash,
builder_params,
}| { accum::resolve_reducer(typehash, builder_params) },
)
.transpose()?
.flatten();
let enqueue: Box<
dyn Fn(Serialized) -> Result<(), (Serialized, anyhow::Error)> + Send + Sync,
> = match reducer {
Expand Down Expand Up @@ -1349,7 +1358,7 @@ impl cap::sealed::CanSplitPort for Mailbox {
port_id: split_port.clone(),
},
);
split_port
Ok(split_port)
}
}

Expand All @@ -1375,7 +1384,7 @@ pub struct PortHandle<M: Message> {
bound: Arc<OnceLock<PortId>>,
// Typehash of an optional reducer. When it's defined, we include it in port
/// references to optionally enable incremental accumulation.
reducer_typehash: Option<u64>,
reducer_spec: Option<ReducerSpec>,
}

impl<M: Message> PortHandle<M> {
Expand All @@ -1385,7 +1394,7 @@ impl<M: Message> PortHandle<M> {
port_index,
sender,
bound: Arc::new(OnceLock::new()),
reducer_typehash: None,
reducer_spec: None,
}
}

Expand Down Expand Up @@ -1415,7 +1424,7 @@ impl<M: RemoteMessage> PortHandle<M> {
self.bound
.get_or_init(|| self.mailbox.bind(self).port_id().clone())
.clone(),
self.reducer_typehash.clone(),
self.reducer_spec.clone(),
)
}

Expand All @@ -1433,7 +1442,7 @@ impl<M: Message> Clone for PortHandle<M> {
port_index: self.port_index,
sender: self.sender.clone(),
bound: self.bound.clone(),
reducer_typehash: self.reducer_typehash.clone(),
reducer_spec: self.reducer_spec.clone(),
}
}
}
Expand Down Expand Up @@ -2311,18 +2320,18 @@ mod tests {
// accum port could have reducer typehash
{
let accumulator = accum::max::<u64>();
let reducer_typehash = accumulator.reducer_typehash();
let reducer_spec = accumulator.reducer_spec().unwrap();
let (port, _) = mbox.open_accum_port(accum::max::<u64>());
assert_eq!(port.reducer_typehash, Some(reducer_typehash),);
assert_eq!(port.reducer_spec, Some(reducer_spec.clone()));
let port_ref = port.bind();
assert_eq!(port_ref.reducer_typehash(), &Some(reducer_typehash));
assert_eq!(port_ref.reducer_spec(), &Some(reducer_spec));
}
// normal port should not have reducer typehash
{
let (port, _) = mbox.open_port::<u64>();
assert_eq!(port.reducer_typehash, None);
assert_eq!(port.reducer_spec, None);
let port_ref = port.bind();
assert_eq!(port_ref.reducer_typehash(), &None);
assert_eq!(port_ref.reducer_spec(), &None);
}
}

Expand Down Expand Up @@ -2879,7 +2888,7 @@ mod tests {
port_id2_1: PortId,
}

async fn setup_split_port_ids(reducer_typehash: Option<u64>) -> Setup {
async fn setup_split_port_ids(reducer_spec: Option<ReducerSpec>) -> Setup {
let muxer = MailboxMuxer::new();
let actor0 = Mailbox::new(id!(test[0].actor), BoxedMailboxSender::new(muxer.clone()));
let actor1 = Mailbox::new(id!(test[1].actor1), BoxedMailboxSender::new(muxer.clone()));
Expand All @@ -2891,11 +2900,11 @@ mod tests {
let port_id = port_handle.bind().port_id().clone();

// Split it twice on actor1
let port_id1 = port_id.split(&actor1, reducer_typehash.clone());
let port_id2 = port_id.split(&actor1, reducer_typehash.clone());
let port_id1 = port_id.split(&actor1, reducer_spec.clone()).unwrap();
let port_id2 = port_id.split(&actor1, reducer_spec.clone()).unwrap();

// A split port id can also be split
let port_id2_1 = port_id2.split(&actor1, reducer_typehash);
let port_id2_1 = port_id2.split(&actor1, reducer_spec).unwrap();

Setup {
receiver,
Expand Down Expand Up @@ -2968,7 +2977,7 @@ mod tests {
let _config_guard = config.override_key(crate::config::SPLIT_MAX_BUFFER_SIZE, 1);

let sum_accumulator = accum::sum::<u64>();
let reducer_typehash = sum_accumulator.reducer_typehash();
let reducer_spec = sum_accumulator.reducer_spec();
let Setup {
mut receiver,
actor0,
Expand All @@ -2977,7 +2986,7 @@ mod tests {
port_id1,
port_id2,
port_id2_1,
} = setup_split_port_ids(Some(reducer_typehash)).await;
} = setup_split_port_ids(reducer_spec).await;
post(&actor0, port_id.clone(), 4);
post(&actor1, port_id1.clone(), 2);
post(&actor1, port_id2.clone(), 3);
Expand Down Expand Up @@ -3005,8 +3014,8 @@ mod tests {
let (port_handle, mut receiver) = actor.open_port::<u64>();
let port_id = port_handle.bind().port_id().clone();
// Split it
let reducer_typehash = accum::sum::<u64>().reducer_typehash();
let split_port_id = port_id.split(&actor, Some(reducer_typehash));
let reducer_spec = accum::sum::<u64>().reducer_spec();
let split_port_id = port_id.split(&actor, reducer_spec).unwrap();

// Send 9 messages.
for msg in [1, 5, 3, 4, 2, 91, 92, 93, 94] {
Expand Down
25 changes: 24 additions & 1 deletion hyperactor/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ mod tests {

use super::*;
use crate::PortId;
use crate::accum::ReducerSpec;

// Used to demonstrate a user defined reply type.
#[derive(Debug, PartialEq, Serialize, Deserialize, Named)]
Expand All @@ -352,6 +353,8 @@ mod tests {
let mut bindings = Bindings::default();
let ports = [self.reply0.port_id(), self.reply1.port_id()];
bindings.insert::<PortId>(ports)?;
let reducer_specs = [self.reply0.reducer_spec(), self.reply1.reducer_spec()];
bindings.insert::<Option<ReducerSpec>>(reducer_specs)?;
Ok(bindings)
}
}
Expand All @@ -368,7 +371,13 @@ mod tests {
#[test]
fn test_castable() {
let original_port0 = PortRef::attest(id!(world[0].actor[0][123]));
let original_port1 = PortRef::attest(id!(world[1].actor1[0][456]));
let original_port1 = PortRef::attest_reducible(
id!(world[1].actor1[0][456]),
Some(ReducerSpec {
typehash: 123,
builder_params: None,
}),
);
let my_message = MyMessage {
arg0: true,
arg1: 42,
Expand All @@ -389,6 +398,13 @@ mod tests {
Serialized::serialize(original_port0.port_id()).unwrap(),
Serialized::serialize(original_port1.port_id()).unwrap(),
],
Option::<ReducerSpec>::typehash() => vec![
Serialized::serialize(&None::<ReducerSpec>).unwrap(),
Serialized::serialize(&Some(ReducerSpec {
typehash: 123,
builder_params: None,
})).unwrap(),
],
}),
}
);
Expand All @@ -407,6 +423,13 @@ mod tests {
Serialized::serialize(&new_port_id0).unwrap(),
Serialized::serialize(&new_port_id1).unwrap(),
],
Option::<ReducerSpec>::typehash() => vec![
Serialized::serialize(&None::<ReducerSpec>).unwrap(),
Serialized::serialize(&Some(ReducerSpec {
typehash: 123,
builder_params: None,
})).unwrap(),
],
});
assert_eq!(
erased,
Expand Down
5 changes: 3 additions & 2 deletions hyperactor/src/proc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ use crate::Handler;
use crate::Message;
use crate::Named;
use crate::RemoteMessage;
use crate::accum::ReducerSpec;
use crate::actor::ActorError;
use crate::actor::ActorErrorKind;
use crate::actor::ActorHandle;
Expand Down Expand Up @@ -1161,8 +1162,8 @@ impl<A: Actor> cap::sealed::CanOpenPort for Instance<A> {
}

impl<A: Actor> cap::sealed::CanSplitPort for Instance<A> {
fn split(&self, port_id: PortId, reducer_typehash: Option<u64>) -> PortId {
self.mailbox.split(port_id, reducer_typehash)
fn split(&self, port_id: PortId, reducer_spec: Option<ReducerSpec>) -> anyhow::Result<PortId> {
self.mailbox.split(port_id, reducer_spec)
}
}

Expand Down
23 changes: 14 additions & 9 deletions hyperactor/src/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use crate as hyperactor;
use crate::Named;
use crate::RemoteHandles;
use crate::RemoteMessage;
use crate::accum::ReducerSpec;
use crate::actor::RemoteActor;
use crate::cap;
use crate::data::Serialized;
Expand Down Expand Up @@ -725,8 +726,12 @@ impl PortId {

/// Split this port, returning a new port that relays messages to the port
/// through a local proxy, which may coalesce messages.
pub fn split(&self, caps: &impl cap::CanSplitPort, reducer_typehash: Option<u64>) -> PortId {
caps.split(self.clone(), reducer_typehash)
pub fn split(
&self,
caps: &impl cap::CanSplitPort,
reducer_spec: Option<ReducerSpec>,
) -> anyhow::Result<PortId> {
caps.split(self.clone(), reducer_spec)
}
}

Expand Down Expand Up @@ -760,7 +765,7 @@ pub struct PortRef<M: RemoteMessage> {
Ord = "ignore",
Hash = "ignore"
)]
reducer_typehash: Option<u64>,
reducer_spec: Option<ReducerSpec>,
phantom: PhantomData<M>,
}

Expand All @@ -770,17 +775,17 @@ impl<M: RemoteMessage> PortRef<M> {
pub fn attest(port_id: PortId) -> Self {
Self {
port_id,
reducer_typehash: None,
reducer_spec: None,
phantom: PhantomData,
}
}

/// The caller attests that the provided PortId can be
/// converted to a reachable, typed port reference.
pub(crate) fn attest_reducible(port_id: PortId, reducer_typehash: Option<u64>) -> Self {
pub(crate) fn attest_reducible(port_id: PortId, reducer_spec: Option<ReducerSpec>) -> Self {
Self {
port_id,
reducer_typehash,
reducer_spec,
phantom: PhantomData,
}
}
Expand All @@ -793,8 +798,8 @@ impl<M: RemoteMessage> PortRef<M> {

/// The typehash of this port's reducer, if any. Reducers
/// may be used to coalesce messages sent to a port.
pub fn reducer_typehash(&self) -> &Option<u64> {
&self.reducer_typehash
pub fn reducer_spec(&self) -> &Option<ReducerSpec> {
&self.reducer_spec
}

/// This port's ID.
Expand Down Expand Up @@ -843,7 +848,7 @@ impl<M: RemoteMessage> Clone for PortRef<M> {
fn clone(&self) -> Self {
Self {
port_id: self.port_id.clone(),
reducer_typehash: self.reducer_typehash.clone(),
reducer_spec: self.reducer_spec.clone(),
phantom: PhantomData,
}
}
Expand Down
3 changes: 3 additions & 0 deletions hyperactor_mesh/examples/dining_philosophers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use hyperactor::Handler;
use hyperactor::Instance;
use hyperactor::Named;
use hyperactor::PortRef;
use hyperactor::accum::ReducerSpec;
use hyperactor::message::Bind;
use hyperactor::message::Bindings;
use hyperactor::message::IndexedErasedUnbound;
Expand Down Expand Up @@ -90,6 +91,8 @@ impl Unbind for PhilosopherMessage {
Self::Start(port) => {
let ports = [port.port_id()];
bindings.insert(ports)?;
let reducer_specs = [port.reducer_spec()];
bindings.insert::<Option<ReducerSpec>>(reducer_specs)?;
}
Self::GrantChopstick(_) => {}
}
Expand Down
Loading