From 8e593ede10fe75f014d3c612d1f0f1431601b3cc Mon Sep 17 00:00:00 2001 From: Enrico Ghiorzi Date: Wed, 29 May 2024 22:46:39 +0200 Subject: [PATCH] Optimizations for PG and CS, particularly use of Iterator Signed-off-by: Enrico Ghiorzi --- scan_core/src/channel_system.rs | 153 +++++++++++++------------------- scan_core/src/program_graph.rs | 100 +++++++++------------ scan_fmt_xml/tests/basic.rs | 10 ++- src/main.rs | 8 +- 4 files changed, 121 insertions(+), 150 deletions(-) diff --git a/scan_core/src/channel_system.rs b/scan_core/src/channel_system.rs index 8e52f63..73fbd02 100644 --- a/scan_core/src/channel_system.rs +++ b/scan_core/src/channel_system.rs @@ -200,23 +200,22 @@ impl ChannelSystemBuilder { effect: CsExpression, ) -> Result<(), CsError> { if action.0 != pg_id { - return Err(CsError::ActionNotInPg(action, pg_id)); - } - if var.0 != pg_id { - return Err(CsError::VarNotInPg(var, pg_id)); - } - let effect = PgExpression::try_from((pg_id, effect))?; - // Communications cannot have effects - if self.communications.contains_key(&action) { - return Err(CsError::ActionIsCommunication(action)); + Err(CsError::ActionNotInPg(action, pg_id)) + } else if var.0 != pg_id { + Err(CsError::VarNotInPg(var, pg_id)) + } else if self.communications.contains_key(&action) { + // Communications cannot have effects + Err(CsError::ActionIsCommunication(action)) + } else { + let effect = PgExpression::try_from((pg_id, effect))?; + self.program_graphs + .get_mut(pg_id.0) + .ok_or(CsError::MissingPg(pg_id)) + .and_then(|pg| { + pg.add_effect(action.1, var.1, effect) + .map_err(|err| CsError::ProgramGraph(pg_id, err)) + }) } - self.program_graphs - .get_mut(pg_id.0) - .ok_or(CsError::MissingPg(pg_id)) - .and_then(|pg| { - pg.add_effect(action.1, var.1, effect) - .map_err(|err| CsError::ProgramGraph(pg_id, err)) - }) } pub fn new_location(&mut self, pg_id: PgId) -> Result { @@ -235,25 +234,24 @@ impl ChannelSystemBuilder { guard: Option, ) -> Result<(), CsError> { if action.0 != pg_id { - return Err(CsError::ActionNotInPg(action, pg_id)); - } - if pre.0 != pg_id { - return Err(CsError::LocationNotInPg(pre, pg_id)); - } - if post.0 != pg_id { - return Err(CsError::LocationNotInPg(post, pg_id)); + Err(CsError::ActionNotInPg(action, pg_id)) + } else if pre.0 != pg_id { + Err(CsError::LocationNotInPg(pre, pg_id)) + } else if post.0 != pg_id { + Err(CsError::LocationNotInPg(post, pg_id)) + } else { + // Turn CsExpression into a PgExpression for Program Graph pg_id + let guard = guard + .map(|guard| PgExpression::try_from((pg_id, guard))) + .transpose()?; + self.program_graphs + .get_mut(pg_id.0) + .ok_or(CsError::MissingPg(pg_id)) + .and_then(|pg| { + pg.add_transition(pre.1, action.1, post.1, guard) + .map_err(|err| CsError::ProgramGraph(pg_id, err)) + }) } - // Turn CsExpression into a PgExpression for Program Graph pg_id - let guard = guard - .map(|guard| PgExpression::try_from((pg_id, guard))) - .transpose()?; - self.program_graphs - .get_mut(pg_id.0) - .ok_or(CsError::MissingPg(pg_id)) - .and_then(|pg| { - pg.add_transition(pre.1, action.1, post.1, guard) - .map_err(|err| CsError::ProgramGraph(pg_id, err)) - }) } pub fn new_channel(&mut self, var_type: Type, capacity: Option) -> Channel { @@ -297,7 +295,7 @@ impl ChannelSystemBuilder { } }; if channel_type != message_type { - return Err(CsError::ProgramGraph(pg_id, PgError::Mismatched)); + return Err(CsError::ProgramGraph(pg_id, PgError::TypeMismatch)); } let action = self.new_action(pg_id)?; self.communications.insert(action, (channel, message)); @@ -305,15 +303,16 @@ impl ChannelSystemBuilder { } pub fn build(mut self) -> ChannelSystem { - self.program_graphs.shrink_to_fit(); + let mut program_graphs: Vec = self + .program_graphs + .into_iter() + .map(|builder| builder.build()) + .collect(); + program_graphs.shrink_to_fit(); self.channels.shrink_to_fit(); self.communications.shrink_to_fit(); ChannelSystem { - program_graphs: self - .program_graphs - .into_iter() - .map(|builder| builder.build()) - .collect(), + program_graphs, communications: Rc::new(self.communications), message_queue: vec![Vec::default(); self.channels.len()], channels: Rc::new(self.channels), @@ -330,18 +329,19 @@ pub struct ChannelSystem { } impl ChannelSystem { - // Is this function optimized? Does it unnecessarily copy data? - pub fn possible_transitions(&self) -> Vec<(PgId, CsAction, CsLocation)> { + pub fn possible_transitions<'a>( + &'a self, + ) -> impl Iterator + 'a { self.program_graphs .iter() .enumerate() - .flat_map(|(id, pg)| { + .flat_map(move |(id, pg)| { let pg_id = PgId(id); pg.possible_transitions() - .iter() - .filter_map(|(action, post)| { - let action = CsAction(pg_id, *action); - let post = CsLocation(pg_id, *post); + .into_iter() + .filter_map(move |(action, post)| { + let action = CsAction(pg_id, action); + let post = CsLocation(pg_id, post); if self.communications.contains_key(&action) && self.check_communication(pg_id, action).is_err() { @@ -350,49 +350,24 @@ impl ChannelSystem { Some((pg_id, action, post)) } }) - .collect::>() }) - .collect::>() } fn check_communication(&self, pg_id: PgId, action: CsAction) -> Result<(), CsError> { if action.0 != pg_id { - return Err(CsError::ActionNotInPg(action, pg_id)); - } - if let Some((channel, message)) = self.communications.get(&action) { - let (_, capacity) = self - .channels - .get(channel.0) - .expect("communication has been verified before"); - let queue = self - .message_queue - .get(channel.0) - .expect("communication has been verified before"); + Err(CsError::ActionNotInPg(action, pg_id)) + } else if let Some((channel, message)) = self.communications.get(&action) { + let (_, capacity) = self.channels[channel.0]; + let queue = &self.message_queue[channel.0]; + // Channel capacity must never be exeeded! + assert!(capacity.is_none() || capacity.is_some_and(|cap| queue.len() <= cap)); match message { - Message::Send(_) => { - let len = queue.len(); - // Channel capacity must never be exeeded! - assert!(capacity.is_none() || capacity.is_some_and(|c| len <= c)); - if capacity.is_some_and(|c| c == len) { - Err(CsError::OutOfCapacity(*channel)) - } else { - Ok(()) - } - } - Message::Receive(_) => { - if queue.is_empty() { - Err(CsError::Empty(*channel)) - } else { - Ok(()) - } - } - Message::ProbeEmptyQueue => { - if queue.is_empty() { - Ok(()) - } else { - Err(CsError::Empty(*channel)) - } + Message::Send(_) if capacity.is_some_and(|cap| queue.len() == cap) => { + Err(CsError::OutOfCapacity(*channel)) } + Message::Receive(_) if queue.is_empty() => Err(CsError::Empty(*channel)), + Message::ProbeEmptyQueue if !queue.is_empty() => Err(CsError::Empty(*channel)), + _ => Ok(()), } } else { Err(CsError::NoCommunication(action)) @@ -423,10 +398,8 @@ impl ChannelSystem { .map_err(|err| CsError::ProgramGraph(pg_id, err))?; // If the action is a communication, send/receive the message if let Some((channel, message)) = self.communications.get(&action) { - let queue = self - .message_queue - .get_mut(channel.0) - .expect("communication has been verified before"); + // communication has been verified before so there is a queue for channel.0 + let queue = &mut self.message_queue[channel.0]; match message { Message::Send(effect) => { let effect = (pg_id, effect.to_owned()).try_into()?; @@ -556,11 +529,11 @@ mod tests { cs.add_transition(pg2, initial2, receive, post2, None)?; let mut cs = cs.build(); - assert_eq!(cs.possible_transitions().len(), 1); + assert_eq!(cs.possible_transitions().count(), 1); cs.transition(pg1, send, post1)?; cs.transition(pg2, receive, post2)?; - assert!(cs.possible_transitions().is_empty()); + assert_eq!(cs.possible_transitions().count(), 0); Ok(()) } } diff --git a/scan_core/src/program_graph.rs b/scan_core/src/program_graph.rs index ece5596..99168d4 100644 --- a/scan_core/src/program_graph.rs +++ b/scan_core/src/program_graph.rs @@ -34,7 +34,7 @@ pub enum PgError { #[error("location {0:?} does not belong to this program graph")] MissingLocation(Location), #[error("type mismatch")] - Mismatched, + TypeMismatch, #[error("location {0:?} does not belong to this program graph")] NonExistingVar(Var), #[error("There is no such transition")] @@ -113,7 +113,7 @@ impl ProgramGraphBuilder { .ok_or(PgError::MissingAction(action)) .map(|effects| effects.push((var, effect))) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } @@ -141,12 +141,9 @@ impl ProgramGraphBuilder { Err(PgError::MissingAction(action)) } else if guard.is_some() && !matches!(self.r#type(guard.as_ref().unwrap())?, Type::Boolean) { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } else { - let _ = self - .transitions - .get_mut(pre.0) - .expect("location existance already checked") + let _ = self.transitions[pre.0] .entry((action, post)) .and_modify(|previous_guard| { if let Some(guard) = guard.to_owned() { @@ -164,7 +161,7 @@ impl ProgramGraphBuilder { } } }) - .or_insert(guard.to_owned()); + .or_insert(guard); Ok(()) } } @@ -179,7 +176,6 @@ impl ProgramGraphBuilder { .map(|e| self.r#type(e)) .collect::, PgError>>()?, )), - // PgExpression::Const(val) => Ok(val.r#type()), PgExpression::Var(var) => self .vars .get(var.0) @@ -195,7 +191,7 @@ impl ProgramGraphBuilder { { Ok(Type::Boolean) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Implies(props) => { @@ -204,21 +200,21 @@ impl ProgramGraphBuilder { { Ok(Type::Boolean) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Not(prop) => { if matches!(self.r#type(&prop)?, Type::Boolean) { Ok(Type::Boolean) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Opposite(expr) => { if matches!(self.r#type(&expr)?, Type::Integer) { Ok(Type::Integer) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Sum(exprs) | PgExpression::Mult(exprs) => { @@ -231,7 +227,7 @@ impl ProgramGraphBuilder { { Ok(Type::Integer) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Equal(exprs) @@ -244,7 +240,7 @@ impl ProgramGraphBuilder { { Ok(Type::Boolean) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Component(index, expr) => { @@ -254,7 +250,7 @@ impl ProgramGraphBuilder { .cloned() .ok_or(PgError::MissingComponent(*index)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } } @@ -293,10 +289,8 @@ pub struct ProgramGraph { } impl ProgramGraph { - pub fn possible_transitions(&self) -> Vec<(Action, Location)> { - self.transitions - .get(self.current_location.0) - .unwrap_or(&HashMap::new()) + pub fn possible_transitions<'a>(&'a self) -> impl Iterator + 'a { + self.transitions[self.current_location.0] .iter() .filter_map(|((action, post), guard)| { if let Some(guard) = guard { @@ -313,14 +307,10 @@ impl ProgramGraph { Some((*action, *post)) } }) - .collect::>() } pub fn transition(&mut self, action: Action, post_state: Location) -> Result<(), PgError> { - let guard = self - .transitions - .get(self.current_location.0) - .expect("location must exist") + let guard = self.transitions[self.current_location.0] .get(&(action, post_state)) .ok_or(PgError::NoTransition)?; if guard.as_ref().map_or(true, |guard| { @@ -330,18 +320,13 @@ impl ProgramGraph { panic!("guard is not a boolean"); } }) { - for (var, effect) in self - .effects - .get(action.0) - .expect("action has been validated before") - { + for (var, effect) in &self.effects[action.0] { // Not using the 'Self::assign' method because: // - borrow checker // - effects are validated before, so no need to type-check again - *self - .vars - .get_mut(var.0) - .expect("effect has been validated before") = self.eval(effect)?; + self.vars[var.0] = self + .eval(&effect) + .expect("effect has already been validated"); } self.current_location = post_state; Ok(()) @@ -371,7 +356,7 @@ impl ProgramGraph { .cloned() .ok_or(PgError::MissingComponent(*index)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::And(props) => Ok(Val::Boolean( @@ -381,7 +366,7 @@ impl ProgramGraph { if let Val::Boolean(val) = self.eval(prop)? { Ok(val) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } }) .collect::, PgError>>()? @@ -395,7 +380,7 @@ impl ProgramGraph { if let Val::Boolean(val) = self.eval(prop)? { Ok(val) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } }) .collect::, PgError>>()? @@ -408,21 +393,21 @@ impl ProgramGraph { { Ok(Val::Boolean(rhs || !lhs)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Not(prop) => { if let Val::Boolean(arg) = self.eval(prop)? { Ok(Val::Boolean(!arg)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Opposite(expr) => { if let Val::Integer(arg) = self.eval(expr)? { Ok(Val::Integer(-arg)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Sum(exprs) => Ok(Val::Integer( @@ -432,7 +417,7 @@ impl ProgramGraph { if let Val::Integer(val) = self.eval(prop)? { Ok(val) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } }) .collect::, PgError>>()? @@ -446,7 +431,7 @@ impl ProgramGraph { if let Val::Integer(val) = self.eval(prop)? { Ok(val) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } }) .collect::, PgError>>()? @@ -459,7 +444,7 @@ impl ProgramGraph { { Ok(Val::Boolean(lhs == rhs)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Greater(exprs) => { @@ -468,7 +453,7 @@ impl ProgramGraph { { Ok(Val::Boolean(lhs > rhs)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::GreaterEq(exprs) => { @@ -477,7 +462,7 @@ impl ProgramGraph { { Ok(Val::Boolean(lhs >= rhs)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::Less(exprs) => { @@ -486,7 +471,7 @@ impl ProgramGraph { { Ok(Val::Boolean(lhs < rhs)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } PgExpression::LessEq(exprs) => { @@ -495,7 +480,7 @@ impl ProgramGraph { { Ok(Val::Boolean(lhs <= rhs)) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } } @@ -511,7 +496,7 @@ impl ProgramGraph { *var_content = val; Ok(previous_val) } else { - Err(PgError::Mismatched) + Err(PgError::TypeMismatch) } } } @@ -528,9 +513,12 @@ mod tests { let action = builder.new_action(); builder.add_transition(initial, action, r#final, None)?; let mut pg = builder.build(); - assert_eq!(pg.possible_transitions(), vec![(action, r#final)]); + assert_eq!( + pg.possible_transitions().collect::>(), + vec![(action, r#final)] + ); pg.transition(action, r#final)?; - assert!(pg.possible_transitions().is_empty()); + assert_eq!(pg.possible_transitions().count(), 0); Ok(()) } @@ -581,17 +569,17 @@ mod tests { builder.add_transition(center, move_left, left, Some(out_of_charge))?; // Execution let mut pg = builder.build(); - assert_eq!(pg.possible_transitions().len(), 1); + assert_eq!(pg.possible_transitions().count(), 1); pg.transition(initialize, center)?; - assert_eq!(pg.possible_transitions().len(), 2); + assert_eq!(pg.possible_transitions().count(), 2); pg.transition(move_right, right)?; - assert_eq!(pg.possible_transitions().len(), 1); + assert_eq!(pg.possible_transitions().count(), 1); pg.transition(move_right, right).expect_err("already right"); - assert_eq!(pg.possible_transitions().len(), 1); + assert_eq!(pg.possible_transitions().count(), 1); pg.transition(move_left, center)?; - assert_eq!(pg.possible_transitions().len(), 2); + assert_eq!(pg.possible_transitions().count(), 2); pg.transition(move_left, left)?; - assert_eq!(pg.possible_transitions().len(), 0); + assert_eq!(pg.possible_transitions().count(), 0); pg.transition(move_left, left).expect_err("battery = 0"); Ok(()) } diff --git a/scan_fmt_xml/tests/basic.rs b/scan_fmt_xml/tests/basic.rs index 85a5daf..4a996e8 100644 --- a/scan_fmt_xml/tests/basic.rs +++ b/scan_fmt_xml/tests/basic.rs @@ -58,8 +58,14 @@ fn test(file: PathBuf) -> anyhow::Result<()> { let parser = Parser::parse(file)?; let mut model = Sc2CsVisitor::visit(parser)?; let mut steps = 0; - assert!(!model.cs.possible_transitions().is_empty()); - while let Some((pg_id, act, loc)) = model.cs.possible_transitions().first().cloned() { + assert!(model.cs.possible_transitions().count() > 0); + while let Some((pg_id, act, loc)) = model + .cs + .possible_transitions() + .take(1) + .collect::>() + .pop() + { model.cs.transition(pg_id, act, loc)?; steps += 1; if steps >= MAXSTEP { diff --git a/src/main.rs b/src/main.rs index 9732f3c..bba0ac8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,8 +72,12 @@ fn main() -> Result<(), Box> { let mut model = Sc2CsVisitor::visit(model)?; println!("Transitions list:"); let mut trans: u32 = 0; - while let Some((pg_id, action, destination)) = - model.cs.possible_transitions().first().cloned() + while let Some((pg_id, action, destination)) = model + .cs + .possible_transitions() + .take(1) + .collect::>() + .pop() { let pg = model .fsm_names