diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 24a674051..67136329b 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -17,6 +17,7 @@ bench = false [dependencies] hugr-core = { path = "../hugr-core", version = "0.13.3" } +ascent = { version = "0.7.0" } itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } @@ -28,3 +29,6 @@ extension_inference = ["hugr-core/extension_inference"] [dev-dependencies] rstest = { workspace = true } +proptest = { workspace = true } +proptest-derive = { workspace = true } +proptest-recurse = { version = "0.5.0" } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs new file mode 100644 index 000000000..bb3023c38 --- /dev/null +++ b/hugr-passes/src/dataflow.rs @@ -0,0 +1,124 @@ +#![warn(missing_docs)] +//! Dataflow analysis of Hugrs. + +mod datalog; +pub use datalog::Machine; +mod value_row; + +mod results; +pub use results::{AnalysisResults, TailLoopTermination}; + +mod partial_value; +pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; + +use hugr_core::ops::constant::OpaqueValue; +use hugr_core::ops::{ExtensionOp, Value}; +use hugr_core::types::TypeArg; +use hugr_core::{Hugr, Node}; + +/// Clients of the dataflow framework (particular analyses, such as constant folding) +/// must implement this trait (including providing an appropriate domain type `V`). +pub trait DFContext: ConstLoader { + /// Given lattice values for each input, update lattice values for the (dataflow) outputs. + /// For extension ops only, excluding [MakeTuple] and [UnpackTuple] which are handled automatically. + /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] + /// which is the correct value to leave if nothing can be deduced about that output. + /// (The default does nothing, i.e. leaves `Top` for all outputs.) + /// + /// [MakeTuple]: hugr_core::extension::prelude::MakeTuple + /// [UnpackTuple]: hugr_core::extension::prelude::UnpackTuple + fn interpret_leaf_op( + &mut self, + _node: Node, + _e: &ExtensionOp, + _ins: &[PartialValue], + _outs: &mut [PartialValue], + ) { + } +} + +/// A location where a [Value] could be find in a Hugr. That is, +/// (perhaps deeply nested within [Value::Sum]s) within a [Node] +/// that is a [Const](hugr_core::ops::Const). +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ConstLocation<'a> { + /// The specified-index'th field of the [Value::Sum] constant identified by the RHS + Field(usize, &'a ConstLocation<'a>), + /// The entire ([Const::value](hugr_core::ops::Const::value)) of the node. + Node(Node), +} + +impl From for ConstLocation<'_> { + fn from(value: Node) -> Self { + ConstLocation::Node(value) + } +} + +/// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. +/// Implementors will likely want to override some/all of [Self::value_from_opaque], +/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// are "correct" but maximally conservative (minimally informative). +pub trait ConstLoader { + /// Produces an abstract value from an [OpaqueValue], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_opaque(&self, _loc: ConstLocation, _val: &OpaqueValue) -> Option { + None + } + + /// Produces an abstract value from a Hugr in a [Value::Function], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_const_hugr(&self, _loc: ConstLocation, _h: &Hugr) -> Option { + None + } + + /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node + /// (that has been loaded via a [LoadFunction]), if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + /// + /// [FuncDefn]: hugr_core::ops::FuncDefn + /// [FuncDecl]: hugr_core::ops::FuncDecl + /// [LoadFunction]: hugr_core::ops::LoadFunction + fn value_from_function(&self, _node: Node, _type_args: &[TypeArg]) -> Option { + None + } +} + +/// Produces a [PartialValue] from a constant. Traverses [Sum](Value::Sum) constants +/// to their leaves ([Value::Extension] and [Value::Function]), +/// converts these using [ConstLoader::value_from_opaque] and [ConstLoader::value_from_const_hugr], +/// and builds nested [PartialValue::new_variant] to represent the structure. +fn partial_from_const<'a, V>( + cl: &impl ConstLoader, + loc: impl Into>, + cst: &Value, +) -> PartialValue { + let loc = loc.into(); + match cst { + Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { + let elems = values + .iter() + .enumerate() + .map(|(idx, elem)| partial_from_const(cl, ConstLocation::Field(idx, &loc), elem)); + PartialValue::new_variant(*tag, elems) + } + Value::Extension { e } => cl + .value_from_opaque(loc, e) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + Value::Function { hugr } => cl + .value_from_const_hugr(loc, hugr) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + } +} + +/// A row of inputs to a node contains bottom (can't happen, the node +/// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). +pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( + elements: impl IntoIterator>, +) -> bool { + elements.into_iter().any(PartialValue::contains_bottom) +} + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs new file mode 100644 index 000000000..172d87c26 --- /dev/null +++ b/hugr-passes/src/dataflow/datalog.rs @@ -0,0 +1,397 @@ +//! [ascent] datalog implementation of analysis. + +use std::collections::hash_map::RandomState; +use std::collections::HashSet; // Moves to std::hash in Rust 1.76 + +use ascent::lattice::BoundedLattice; +use itertools::Itertools; + +use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; +use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; + +use super::value_row::ValueRow; +use super::{ + partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, + PartialValue, +}; + +type PV = PartialValue; + +/// Basic structure for performing an analysis. Usage: +/// 1. Make a new instance via [Self::new()] +/// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] and/or +/// [Self::prepopulate_df_inputs] with initial values. +/// For example, to analyse a [Module](OpType::Module)-rooted Hugr as a library, +/// [Self::prepopulate_df_inputs] can be used on each externally-callable +/// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. +/// 3. Call [Self::run] to produce [AnalysisResults] +pub struct Machine(H, Vec<(Node, IncomingPort, PartialValue)>); + +impl Machine { + /// Create a new Machine to analyse the given Hugr(View) + pub fn new(hugr: H) -> Self { + Self(hugr, Default::default()) + } +} + +impl Machine { + /// Provide initial values for a wire - these will be `join`d with any computed. + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + self.1.extend( + self.0 + .linked_inputs(w.node(), w.source()) + .map(|(n, inp)| (n, inp, v.clone())), + ); + } + + /// Provide initial values for the inputs to a [DataflowParent](hugr_core::ops::OpTag::DataflowParent) + /// (that is, values on the wires leaving the [Input](OpType::Input) child thereof). + /// Any out-ports of said same `Input` node, not given values by `in_values`, are set to [PartialValue::Top]. + pub fn prepopulate_df_inputs( + &mut self, + parent: Node, + in_values: impl IntoIterator)>, + ) { + // Put values onto out-wires of Input node + let [inp, _] = self.0.get_io(parent).unwrap(); + let mut vals = vec![PartialValue::Top; self.0.signature(inp).unwrap().output_types().len()]; + for (ip, v) in in_values { + vals[ip.index()] = v; + } + for (i, v) in vals.into_iter().enumerate() { + self.prepopulate_wire(Wire::new(inp, i), v); + } + } + + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// given initial values for some of the root node inputs. For a + /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. + /// The context passed in allows interpretation of leaf operations. + /// + /// # Panics + /// May panic in various ways if the Hugr is invalid; + /// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`. + pub fn run( + mut self, + context: impl DFContext, + in_values: impl IntoIterator)>, + ) -> AnalysisResults { + let mut in_values = in_values.into_iter(); + let root = self.0.root(); + // Some nodes do not accept values as dataflow inputs - for these + // we must find the corresponding Input node. + let input_node_parent = match self.0.get_optype(root) { + OpType::Module(_) => { + let main = self.0.children(root).find(|n| { + self.0 + .get_optype(*n) + .as_func_defn() + .is_some_and(|f| f.name == "main") + }); + if main.is_none() && in_values.next().is_some() { + panic!("Cannot give inputs to module with no 'main'"); + } + main + } + OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => Some(root), + // Could also do Dfg above, but ok here too: + _ => None, // Just feed into node inputs + }; + // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis + // (Consider: for a conditional that selects *either* the unknown input *or* value V, + // analysis must produce Top == we-know-nothing, not `V` !) + if let Some(p) = input_node_parent { + self.prepopulate_df_inputs( + p, + in_values.map(|(p, v)| (OutgoingPort::from(p.index()), v)), + ); + } else { + // Put values onto in-wires of root node, datalog will do the rest + self.1.extend(in_values.map(|(p, v)| (root, p, v))); + let got_inputs: HashSet<_, RandomState> = self + .1 + .iter() + .filter_map(|(n, p, _)| (n == &root).then_some(*p)) + .collect(); + for p in self.0.signature(root).unwrap_or_default().input_ports() { + if !got_inputs.contains(&p) { + self.1.push((root, p, PartialValue::Top)); + } + } + } + // Note/TODO, if analysis is running on a subregion then we should do similar + // for any nonlocal edges providing values from outside the region. + run_datalog(context, self.0, self.1) + } +} + +pub(super) fn run_datalog( + mut ctx: impl DFContext, + hugr: H, + in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, +) -> AnalysisResults { + // ascent-(macro-)generated code generates a bunch of warnings, + // keep code in here to a minimum. + #![allow( + clippy::clone_on_copy, + clippy::unused_enumerate_index, + clippy::collapsible_if + )] + let all_results = ascent::ascent_run! { + pub(super) struct AscentProgram; + relation node(Node); // exists in the hugr + relation in_wire(Node, IncomingPort); // has an of `EdgeKind::Value` + relation out_wire(Node, OutgoingPort); // has an of `EdgeKind::Value` + relation parent_of_node(Node, Node); // is parent of + relation input_child(Node, Node); // has 1st child that is its `Input` + relation output_child(Node, Node); // has 2nd child that is its `Output` + lattice out_wire_value(Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(Node, ValueRow); // 's inputs are + + node(n) <-- for n in hugr.nodes(); + + in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n); // (and likewise) + + parent_of_node(parent, child) <-- + node(child), if let Some(parent) = hugr.get_parent(*child); + + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = hugr.get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = hugr.get_io(*parent); + + // Initialize all wires to bottom + out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); + + // Outputs to inputs + in_wire_value(n, ip, v) <-- in_wire(n, ip), + if let Some((m, op)) = hugr.single_linked_output(*n, *ip), + out_wire_value(m, op, v); + + // Prepopulate in_wire_value from in_wire_value_proto. + in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); + in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), + node(n), + if let Some(sig) = hugr.signature(*n), + if sig.input_ports().contains(p); + + // Assemble node_in_value_row from in_wire_value's + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n); + node_in_value_row(n, ValueRow::new(hugr.signature(*n).unwrap().input_count()).set(p.index(), v.clone())) <-- in_wire_value(n, p, v); + + // Interpret leaf ops + out_wire_value(n, p, v) <-- + node(n), + let op_t = hugr.get_optype(*n), + if !op_t.is_container(), + if let Some(sig) = op_t.dataflow_signature(), + node_in_value_row(n, vs), + if let Some(outs) = propagate_leaf_op(&mut ctx, &hugr, *n, &vs[..], sig.output_count()), + for (p, v) in (0..).map(OutgoingPort::from).zip(outs); + + // DFG -------------------- + relation dfg_node(Node); // is a `DFG` + dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); + + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + input_child(dfg, i), in_wire_value(dfg, p, v); + + out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + output_child(dfg, o), in_wire_value(o, p, v); + + // TailLoop -------------------- + // inputs of tail loop propagate to Input node of child region + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), + if hugr.get_optype(*tl).is_tail_loop(), + input_child(tl, i), + in_wire_value(tl, p, v); + + // Output node of child region propagate to Input node of child region + out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), + input_child(tl, in_n), + output_child(tl, out_n), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... + // ...and select just what's possible for CONTINUE_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), + for (out_p, v) in fields.enumerate(); + + // Output node of child region propagate to outputs of tail loop + out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), + output_child(tl, out_n), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... + // ... and select just what's possible for BREAK_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), + for (out_p, v) in fields.enumerate(); + + // Conditional -------------------- + // is a `Conditional` and its 'th child (a `Case`) is : + relation case_node(Node, usize, Node); + case_node(cond, i, case) <-- node(cond), + if hugr.get_optype(*cond).is_conditional(), + for (i, case) in hugr.children(*cond).enumerate(), + if hugr.get_optype(case).is_case(); + + // inputs of conditional propagate into case nodes + out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- + case_node(cond, case_index, case), + input_child(case, i_node), + node_in_value_row(cond, in_row), + let conditional = hugr.get_optype(*cond).as_conditional().unwrap(), + if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + for (out_p, v) in fields.enumerate(); + + // outputs of case nodes propagate to outputs of conditional *if* case reachable + out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- + case_node(cond, _i, case), + case_reachable(cond, case), + output_child(case, o), + in_wire_value(o, o_p, v); + + // In `Conditional` , child `Case` is reachable given our knowledge of predicate: + relation case_reachable(Node, Node); + case_reachable(cond, case) <-- case_node(cond, i, case), + in_wire_value(cond, IncomingPort::from(0), v), + if v.supports_tag(*i); + + // CFG -------------------- + relation cfg_node(Node); // is a `CFG` + cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); + + // In `CFG` , basic block is reachable given our knowledge of predicates: + relation bb_reachable(Node, Node); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); + bb_reachable(cfg, bb) <-- cfg_node(cfg), + bb_reachable(cfg, pred), + output_child(pred, pred_out), + in_wire_value(pred_out, IncomingPort::from(0), predicate), + for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), + if predicate.supports_tag(tag); + + // Inputs of CFG propagate to entry block + out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- + cfg_node(cfg), + if let Some(entry) = hugr.children(*cfg).next(), + input_child(entry, i_node), + in_wire_value(cfg, p, v); + + // In `CFG` , values fed along a control-flow edge to + // come out of Value outports of : + relation _cfg_succ_dest(Node, Node, Node); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); + _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), + for blk in hugr.children(*cfg), + if hugr.get_optype(blk).is_dataflow_block(), + input_child(blk, inp); + + // Outputs of each reachable block propagated to successor block or CFG itself + out_wire_value(dest, OutgoingPort::from(out_p), v) <-- + bb_reachable(cfg, pred), + if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), + output_child(pred, out_n), + _cfg_succ_dest(cfg, succ, dest), + node_in_value_row(out_n, out_in_row), + if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + for (out_p, v) in fields.enumerate(); + + // Call -------------------- + relation func_call(Node, Node); // is a `Call` to `FuncDefn` + func_call(call, func_defn) <-- + node(call), + if hugr.get_optype(*call).is_call(), + if let Some(func_defn) = hugr.static_source(*call); + + out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + input_child(func, inp), + in_wire_value(call, p, v); + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + output_child(func, outp), + in_wire_value(outp, p, v); + }; + let out_wire_values = all_results + .out_wire_value + .iter() + .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(); + AnalysisResults { + hugr, + out_wire_values, + in_wire_value: all_results.in_wire_value, + case_reachable: all_results.case_reachable, + bb_reachable: all_results.bb_reachable, + } +} + +fn propagate_leaf_op( + ctx: &mut impl DFContext, + hugr: &impl HugrView, + n: Node, + ins: &[PV], + num_outs: usize, +) -> Option> { + match hugr.get_optype(n) { + // Handle basics here. We could instead leave these to DFContext, + // but at least we'd want these impls to be easily reusable. + op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( + 0, + ins.iter().cloned(), + )])), + op if op.cast::().is_some() => { + let elem_tys = op.cast::().unwrap().0; + let tup = ins.iter().exactly_one().unwrap(); + tup.variant_values(0, elem_tys.len()) + .map(ValueRow::from_iter) + } + OpType::Tag(t) => Some(ValueRow::from_iter([PV::new_variant( + t.tag, + ins.iter().cloned(), + )])), + OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent + OpType::Call(_) => None, // handled via Input/Output of FuncDefn + OpType::Const(_) => None, // handled by LoadConstant: + OpType::LoadConstant(load_op) => { + assert!(ins.is_empty()); // static edge, so need to find constant + let const_node = hugr + .single_linked_output(n, load_op.constant_port()) + .unwrap() + .0; + let const_val = hugr.get_optype(const_node).as_const().unwrap().value(); + Some(ValueRow::singleton(partial_from_const(ctx, n, const_val))) + } + OpType::LoadFunction(load_op) => { + assert!(ins.is_empty()); // static edge + let func_node = hugr + .single_linked_output(n, load_op.function_port()) + .unwrap() + .0; + // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself + Some(ValueRow::singleton( + ctx.value_from_function(func_node, &load_op.type_args) + .map_or(PV::Top, PV::Value), + )) + } + OpType::ExtensionOp(e) => { + Some(ValueRow::from_iter(if row_contains_bottom(ins) { + // So far we think one or more inputs can't happen. + // So, don't pollute outputs with Top, and wait for better knowledge of inputs. + vec![PartialValue::Bottom; num_outs] + } else { + // Interpret op using DFContext + // Default to Top i.e. can't figure out anything about the outputs + let mut outs = vec![PartialValue::Top; num_outs]; + // It might be nice to convert `ins` to [(IncomingPort, Value)], or some + // other concrete value, for the context, but PV contains more information, + // and try_into_concrete may fail. + ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); + outs + })) + } + o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + } +} diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs new file mode 100644 index 000000000..f2a497806 --- /dev/null +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -0,0 +1,706 @@ +use ascent::lattice::BoundedLattice; +use ascent::Lattice; +use hugr_core::ops::Value; +use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use itertools::{zip_eq, Itertools}; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use thiserror::Error; + +use super::row_contains_bottom; + +/// Trait for an underlying domain of abstract values which can form the *elements* of a +/// [PartialValue] and thus be used in dataflow analysis. +pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { + /// Computes the join of two values (i.e. towards `Top``), if this is representable + /// within the underlying domain. Return the new value, and whether this is different from + /// the old `self`. + /// + /// If the join is not representable, return `None` - i.e., we should use [PartialValue::Top]. + /// + /// The default checks equality between `self` and `other` and returns `(self,false)` if + /// the two are identical, otherwise `None`. + fn try_join(self, other: Self) -> Option<(Self, bool)> { + (self == other).then_some((self, false)) + } + + /// Computes the meet of two values (i.e. towards `Bottom`), if this is representable + /// within the underlying domain. Return the new value, and whether this is different from + /// the old `self`. + /// If the meet is not representable, return `None` - i.e., we should use [PartialValue::Bottom]. + /// + /// The default checks equality between `self` and `other` and returns `(self, false)` if + /// the two are identical, otherwise `None`. + fn try_meet(self, other: Self) -> Option<(Self, bool)> { + (self == other).then_some((self, false)) + } +} + +/// Represents a sum with a single/known tag, abstracted over the representation of the elements. +/// (Identical to [Sum](hugr_core::ops::constant::Sum) except for the type abstraction.) +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Sum { + /// The tag index of the variant. + pub tag: usize, + /// The value of the variant. + /// + /// Sum variants are always a row of values, hence the Vec. + pub values: Vec, + /// The full type of the Sum, including the other variants. + pub st: SumType, +} + +/// A representation of a value of [SumType], that may have one or more possible tags, +/// with a [PartialValue] representation of each element-value of each possible tag. +#[derive(PartialEq, Clone, Eq)] +pub struct PartialSum(pub HashMap>>); + +impl PartialSum { + /// New instance for a single known tag. + /// (Multi-tag instances can be created via [Self::try_join_mut].) + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { + Self(HashMap::from([(tag, Vec::from_iter(values))])) + } + + /// The number of possible variants we know about. (NOT the number + /// of tags possible for the value's type, whatever [SumType] that might be.) + pub fn num_variants(&self) -> usize { + self.0.len() + } + + fn assert_invariants(&self) { + assert_ne!(self.num_variants(), 0); + for pv in self.0.values().flat_map(|x| x.iter()) { + pv.assert_invariants(); + } + } +} + +impl PartialSum { + /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns + /// whether `self` has changed. + /// + /// Fails (without mutation) with the conflicting tag if any common rows have different lengths. + pub fn try_join_mut(&mut self, other: Self) -> Result { + for (k, v) in &other.0 { + if self.0.get(k).is_some_and(|row| row.len() != v.len()) { + return Err(*k); + } + } + let mut changed = false; + + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + self.0.insert(k, v); + changed = true; + } + } + Ok(changed) + } + + /// Mutates self according to lattice meet operation (towards `Bottom`). If successful, + /// returns whether `self` has changed. + /// + /// # Errors + /// Fails without mutation, either: + /// * `Some(tag)` if the two [PartialSum]s both had rows with that `tag` but of different lengths + /// * `None` if the two instances had no rows in common (i.e., the result is "Bottom") + pub fn try_meet_mut(&mut self, other: Self) -> Result> { + let mut changed = false; + let mut keys_to_remove = vec![]; + for (k, v) in self.0.iter() { + match other.0.get(k) { + None => keys_to_remove.push(*k), + Some(o_v) => { + if v.len() != o_v.len() { + return Err(Some(*k)); + } + } + } + } + if keys_to_remove.len() == self.0.len() { + return Err(None); + } + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.meet_mut(rhs); + } + } else { + keys_to_remove.push(k); + } + } + for k in keys_to_remove { + self.0.remove(&k); + changed = true; + } + Ok(changed) + } + + /// Whether this sum might have the specified tag + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } + + /// Turns this instance into a [Sum] of some "concrete" value type `C`, + /// *if* this PartialSum has exactly one possible tag. + /// + /// # Errors + /// + /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] + /// supporting the single possible tag with the correct number of elements and no row variables; + /// or if converting a child element failed via [PartialValue::try_into_concrete]. + pub fn try_into_sum(self, typ: &Type) -> Result, ExtractValueError> + where + V: TryInto, + Sum: TryInto, + { + if self.0.len() != 1 { + return Err(ExtractValueError::MultipleVariants(self)); + } + let (tag, v) = self.0.into_iter().exactly_one().unwrap(); + if let TypeEnum::Sum(st) = typ.as_type_enum() { + if let Some(r) = st.get_variant(tag) { + if let Ok(r) = TypeRow::try_from(r.clone()) { + if v.len() == r.len() { + return Ok(Sum { + tag, + values: zip_eq(v, r.iter()) + .map(|(v, t)| v.try_into_concrete(t)) + .collect::, _>>()?, + st: st.clone(), + }); + } + } + } + } + Err(ExtractValueError::BadSumType { + typ: typ.clone(), + tag, + num_elements: v.len(), + }) + } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } +} + +/// An error converting a [PartialValue] or [PartialSum] into a concrete value type +/// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] +#[derive(Clone, Debug, PartialEq, Eq, Error)] +#[allow(missing_docs)] +pub enum ExtractValueError { + #[error("PartialSum value had multiple possible tags: {0}")] + MultipleVariants(PartialSum), + #[error("Value contained `Bottom`")] + ValueIsBottom, + #[error("Value contained `Top`")] + ValueIsTop, + #[error("Could not convert element from abstract value into concrete: {0}")] + CouldNotConvert(V, #[source] VE), + #[error("Could not build Sum from concrete element values")] + CouldNotBuildSum(#[source] SE), + #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] + BadSumType { + typ: Type, + tag: usize, + num_elements: usize, + }, +} + +impl PartialSum { + /// If this Sum might have the specified `tag`, get the elements inside that tag. + pub fn variant_values(&self, variant: usize) -> Option>> { + self.0.get(&variant).cloned() + } +} + +impl PartialOrd for PartialSum { + fn partial_cmp(&self, other: &Self) -> Option { + let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); + let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); + for k in self.0.keys() { + keys1[*k] = 1; + } + + for k in other.0.keys() { + keys2[*k] = 1; + } + + Some(match keys1.cmp(&keys2) { + ord @ Ordering::Greater | ord @ Ordering::Less => ord, + Ordering::Equal => { + for (k, lhs) in &self.0 { + let Some(rhs) = other.0.get(k) else { + unreachable!() + }; + let key_cmp = lhs.partial_cmp(rhs); + if key_cmp != Some(Ordering::Equal) { + return key_cmp; + } + } + Ordering::Equal + } + }) + } +} + +impl std::fmt::Debug for PartialSum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Hash for PartialSum { + fn hash(&self, state: &mut H) { + for (k, v) in &self.0 { + k.hash(state); + v.hash(state); + } + } +} + +/// Wraps some underlying representation (knowledge) of values into a lattice +/// for use in dataflow analysis, including that an instance may be a [PartialSum] +/// of values of the underlying representation +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PartialValue { + /// No possibilities known (so far) + Bottom, + /// A single value (of the underlying representation) + Value(V), + /// Sum (with at least one, perhaps several, possible tags) of underlying values + PartialSum(PartialSum), + /// Might be more than one distinct value of the underlying type `V` + Top, +} + +impl From for PartialValue { + fn from(v: V) -> Self { + Self::Value(v) + } +} + +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { + Self::PartialSum(v) + } +} + +impl PartialValue { + fn assert_invariants(&self) { + if let Self::PartialSum(ps) = self { + ps.assert_invariants(); + } + } + + /// New instance of a sum with a single known tag. + pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::new_variant(tag, values).into() + } + + /// New instance of unit type (i.e. the only possible value, with no contents) + pub fn new_unit() -> Self { + Self::new_variant(0, []) + } +} + +impl PartialValue { + /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. + /// + /// # Panics + /// + /// if the value is believed, for that tag, to have a number of values other than `len` + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + let vals = match self { + PartialValue::Bottom | PartialValue::Value(_) => return None, + PartialValue::PartialSum(ps) => ps.variant_values(tag)?, + PartialValue::Top => vec![PartialValue::Top; len], + }; + assert_eq!(vals.len(), len); + Some(vals) + } + + /// Tells us whether this value might be a Sum with the specified `tag` + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom | PartialValue::Value(_) => false, + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// Turns this instance into some "concrete" value type `C`, *if* it is a single value, + /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by + /// [PartialSum::try_into_sum]. + /// + /// # Errors + /// + /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) + /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is + /// incorrect), or if that [Sum] could not be converted into a `V2`. + pub fn try_into_concrete(self, typ: &Type) -> Result> + where + V: TryInto, + Sum: TryInto, + { + match self { + Self::Value(v) => v + .clone() + .try_into() + .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), + Self::PartialSum(ps) => ps + .try_into_sum(typ)? + .try_into() + .map_err(ExtractValueError::CouldNotBuildSum), + Self::Top => Err(ExtractValueError::ValueIsTop), + Self::Bottom => Err(ExtractValueError::ValueIsBottom), + } + } + + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } +} + +impl TryFrom> for Value { + type Error = ConstTypeError; + + fn try_from(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } +} + +impl Lattice for PartialValue { + fn join_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); + let mut old_self = Self::Top; + std::mem::swap(self, &mut old_self); + let (res, ch) = match (old_self, other) { + (old @ Self::Top, _) | (old, Self::Bottom) => (old, false), + (_, other @ Self::Top) | (Self::Bottom, other) => (other, true), + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { + Some((h3, b)) => (Self::Value(h3), b), + None => (Self::Top, true), + }, + (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { + Ok(ch) => (Self::PartialSum(ps1), ch), + Err(_) => (Self::Top, true), + }, + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + (Self::Top, true) + } + }; + *self = res; + ch + } + + fn meet_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); + let mut old_self = Self::Bottom; + std::mem::swap(self, &mut old_self); + let (res, ch) = match (old_self, other) { + (old @ Self::Bottom, _) | (old, Self::Top) => (old, false), + (_, other @ Self::Bottom) | (Self::Top, other) => (other, true), + (Self::Value(h1), Self::Value(h2)) => match h1.try_meet(h2) { + Some((h3, ch)) => (Self::Value(h3), ch), + None => (Self::Bottom, true), + }, + (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { + Ok(ch) => (Self::PartialSum(ps1), ch), + Err(_) => (Self::Bottom, true), + }, + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + (Self::Bottom, true) + } + }; + *self = res; + ch + } +} + +impl BoundedLattice for PartialValue { + fn top() -> Self { + Self::Top + } + + fn bottom() -> Self { + Self::Bottom + } +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + match (self, other) { + (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), + (Self::Top, Self::Top) => Some(Ordering::Equal), + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), + _ => None, + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use ascent::{lattice::BoundedLattice, Lattice}; + use itertools::{zip_eq, Itertools as _}; + use prop::sample::subsequence; + use proptest::prelude::*; + + use proptest_recurse::{StrategyExt, StrategySet}; + + use super::{AbstractValue, PartialSum, PartialValue}; + + #[derive(Debug, PartialEq, Eq, Clone)] + enum TestSumType { + Branch(Vec>>), + /// None => unit, Some => TestValue <= this *usize* + Leaf(Option), + } + + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct TestValue(usize); + + impl AbstractValue for TestValue {} + + #[derive(Clone)] + struct SumTypeParams { + depth: usize, + desired_size: usize, + expected_branch_size: usize, + } + + impl Default for SumTypeParams { + fn default() -> Self { + Self { + depth: 5, + desired_size: 20, + expected_branch_size: 5, + } + } + } + + impl TestSumType { + fn check_value(&self, pv: &PartialValue) -> bool { + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, + (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), + (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::Branch(sop), PartialValue::PartialSum(ps)) => { + for (k, v) in &ps.0 { + if *k >= sop.len() { + return false; + } + let prod = &sop[*k]; + if prod.len() != v.len() { + return false; + } + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.check_value(rhs)) { + return false; + } + } + true + } + _ => false, + } + } + } + + impl Arbitrary for TestSumType { + type Parameters = SumTypeParams; + type Strategy = SBoxedStrategy; + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { + use proptest::collection::vec; + let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); + let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; + leaf_strat.prop_mutually_recursive( + params.depth as u32, + params.desired_size as u32, + params.expected_branch_size as u32, + set, + move |set| { + let params2 = params.clone(); + vec( + vec( + set.get::(move |set| arb(params2, set)) + .prop_map(Arc::new), + 1..=params.expected_branch_size, + ), + 1..=params.expected_branch_size, + ) + .prop_map(TestSumType::Branch) + .sboxed() + }, + ) + } + + arb(params, &mut StrategySet::default()) + } + } + + fn single_sum_strat( + tag: usize, + elems: Vec>, + ) -> impl Strategy> { + elems + .iter() + .map(Arc::as_ref) + .map(any_partial_value_of_type) + .collect::>() + .prop_map(move |elems| PartialSum::new_variant(tag, elems)) + } + + fn partial_sum_strat( + variants: &[Vec>], + ) -> impl Strategy> { + // We have to clone the `variants` here but only as far as the Vec>> + let tagged_variants = variants.iter().cloned().enumerate().collect::>(); + // The type annotation here (and the .boxed() enabling it) are just for documentation + let sum_variants_strat: BoxedStrategy>> = + subsequence(tagged_variants, 1..=variants.len()) + .prop_flat_map(|selected_variants| { + selected_variants + .into_iter() + .map(|(tag, elems)| single_sum_strat(tag, elems)) + .collect::>() + }) + .boxed(); + sum_variants_strat.prop_map(|psums: Vec>| { + let mut psums = psums.into_iter(); + let first = psums.next().unwrap(); + psums.fold(first, |mut a, b| { + a.try_join_mut(b).unwrap(); + a + }) + }) + } + + fn any_partial_value_of_type( + ust: &TestSumType, + ) -> impl Strategy> { + match ust { + TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), + TestSumType::Leaf(Some(i)) => (0..*i) + .prop_map(TestValue) + .prop_map(PartialValue::from) + .boxed(), + TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), + } + } + + fn any_partial_value_with( + params: ::Parameters, + ) -> impl Strategy> { + any_with::(params).prop_flat_map(|t| any_partial_value_of_type(&t)) + } + + fn any_partial_value() -> impl Strategy> { + any_partial_value_with(Default::default()) + } + + fn any_partial_values() -> impl Strategy; N]> { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(&ust)) + .collect_vec(), + ) + .unwrap() + }) + } + + fn any_typed_partial_value() -> impl Strategy)> { + any::() + .prop_flat_map(|t| any_partial_value_of_type(&t).prop_map(move |v| (t.clone(), v))) + } + + proptest! { + #[test] + fn partial_value_type((tst, pv) in any_typed_partial_value()) { + prop_assert!(tst.check_value(&pv)) + } + + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); + } + + #[test] + fn meet_join_self_noop(v1 in any_partial_value()) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + prop_assert!(meet == v2.clone().meet(v1.clone()), "meet not symmetric"); + prop_assert!(meet == meet.clone().meet(v1.clone()), "repeated meet should be a no-op"); + prop_assert!(meet == meet.clone().meet(v2.clone()), "repeated meet should be a no-op"); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + prop_assert!(join == v2.clone().join(v1.clone()), "join not symmetric"); + prop_assert!(join == join.clone().join(v1.clone()), "repeated join should be a no-op"); + prop_assert!(join == join.clone().join(v2.clone()), "repeated join should be a no-op"); + } + + #[test] + fn lattice_associative([v1, v2, v3] in any_partial_values()) { + let a = v1.clone().meet(v2.clone()).meet(v3.clone()); + let b = v1.clone().meet(v2.clone().meet(v3.clone())); + prop_assert!(a==b, "meet not associative"); + + let a = v1.clone().join(v2.clone()).join(v3.clone()); + let b = v1.clone().join(v2.clone().join(v3.clone())); + prop_assert!(a==b, "join not associative") + } + } +} diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs new file mode 100644 index 000000000..0f4704b42 --- /dev/null +++ b/hugr-passes/src/dataflow/results.rs @@ -0,0 +1,126 @@ +use std::collections::HashMap; + +use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; + +use super::{partial_value::ExtractValueError, AbstractValue, PartialValue, Sum}; + +/// Results of a dataflow analysis, packaged with the Hugr for easy inspection. +/// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). +pub struct AnalysisResults { + pub(super) hugr: H, + pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, + pub(super) case_reachable: Vec<(Node, Node)>, + pub(super) bb_reachable: Vec<(Node, Node)>, + pub(super) out_wire_values: HashMap>, +} + +impl AnalysisResults { + /// Gets the lattice value computed for the given wire + pub fn read_out_wire(&self, w: Wire) -> Option> { + self.out_wire_values.get(&w).cloned() + } + + /// Tells whether a [TailLoop] node can terminate, i.e. whether + /// `Break` and/or `Continue` tags may be returned by the nested DFG. + /// Returns `None` if the specified `node` is not a [TailLoop]. + /// + /// [TailLoop]: hugr_core::ops::TailLoop + pub fn tail_loop_terminates(&self, node: Node) -> Option { + self.hugr.get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr.get_io(node).unwrap(); + Some(TailLoopTermination::from_control_value( + self.in_wire_value + .iter() + .find_map(|(n, p, v)| (*n == out && p.index() == 0).then_some(v)) + .unwrap(), + )) + } + + /// Tells whether a [Case] node is reachable, i.e. whether the predicate + /// to its parent [Conditional] may possibly have the tag corresponding to the [Case]. + /// Returns `None` if the specified `case` is not a [Case], or is not within a [Conditional] + /// (e.g. a [Case]-rooted Hugr). + /// + /// [Case]: hugr_core::ops::Case + /// [Conditional]: hugr_core::ops::Conditional + pub fn case_reachable(&self, case: Node) -> Option { + self.hugr.get_optype(case).as_case()?; + let cond = self.hugr.get_parent(case)?; + self.hugr.get_optype(cond).as_conditional()?; + Some( + self.case_reachable + .iter() + .any(|(cond2, case2)| &cond == cond2 && &case == case2), + ) + } + + /// Tells us if a block ([DataflowBlock] or [ExitBlock]) in a [CFG] is known + /// to be reachable. (Returns `None` if argument is not a child of a CFG.) + /// + /// [CFG]: hugr_core::ops::CFG + /// [DataflowBlock]: hugr_core::ops::DataflowBlock + /// [ExitBlock]: hugr_core::ops::ExitBlock + pub fn bb_reachable(&self, bb: Node) -> Option { + let cfg = self.hugr.get_parent(bb)?; // Not really required...?? + self.hugr.get_optype(cfg).as_cfg()?; + let t = self.hugr.get_optype(bb); + (t.is_dataflow_block() || t.is_exit_block()).then(|| { + self.bb_reachable + .iter() + .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb) + }) + } + + /// Reads a concrete representation of the value on an output wire, if the lattice value + /// computed for the wire can be turned into such. (The lattice value must be either a + /// [PartialValue::Value] or a [PartialValue::PartialSum] with a single possible tag.) + /// + /// # Errors + /// `None` if the analysis did not produce a result for that wire, or if + /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire + /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` + pub fn try_read_wire_concrete( + &self, + w: Wire, + ) -> Result>> + where + V2: TryFrom + TryFrom, Error = SE>, + { + let v = self.read_out_wire(w).ok_or(None)?; + let (_, typ) = self + .hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .ok_or(None)?; + v.try_into_concrete(&typ).map_err(Some) + } +} + +/// Tells whether a loop iterates (never, always, sometimes) +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +pub enum TailLoopTermination { + /// The loop never exits (is an infinite loop); no value is ever + /// returned out of the loop. (aka, Bottom.) + // TODO what about a loop that never exits OR continues because of a nested infinite loop? + NeverBreaks, + /// The loop never iterates (so is equivalent to a [DFG](hugr_core::ops::DFG), + /// modulo untupling of the control value) + NeverContinues, + /// The loop might iterate and/or exit. (aka, Top) + BreaksAndContinues, +} + +impl TailLoopTermination { + fn from_control_value(v: &PartialValue) -> Self { + let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); + if may_break { + if may_continue { + Self::BreaksAndContinues + } else { + Self::NeverContinues + } + } else { + Self::NeverBreaks + } + } +} diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs new file mode 100644 index 000000000..13815d186 --- /dev/null +++ b/hugr-passes/src/dataflow/test.rs @@ -0,0 +1,548 @@ +use ascent::{lattice::BoundedLattice, Lattice}; + +use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; +use hugr_core::extension::PRELUDE_REGISTRY; +use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; +use hugr_core::ops::handle::DfgID; +use hugr_core::ops::TailLoop; +use hugr_core::types::TypeRow; +use hugr_core::{ + builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, + extension::{ + prelude::{bool_t, UnpackTuple}, + ExtensionSet, EMPTY_REG, + }, + ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, + type_row, + types::{Signature, SumType, Type}, + HugrView, +}; +use hugr_core::{Hugr, Wire}; +use rstest::{fixture, rstest}; + +use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; + +// ------- Minimal implementation of DFContext and AbstractValue ------- +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum Void {} + +impl AbstractValue for Void {} + +struct TestContext; + +impl ConstLoader for TestContext {} +impl DFContext for TestContext {} + +// This allows testing creation of tuple/sum Values (only) +impl From for Value { + fn from(v: Void) -> Self { + match v {} + } +} + +fn pv_false() -> PartialValue { + PartialValue::new_variant(0, []) +} + +fn pv_true() -> PartialValue { + PartialValue::new_variant(1, []) +} + +fn pv_true_or_false() -> PartialValue { + pv_true().join(pv_false()) +} + +#[test] +fn test_make_tuple() { + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let x: Value = results.try_read_wire_concrete(v3).unwrap(); + assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); +} + +#[test] +fn test_unpack_tuple_const() { + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); + let v = builder.add_load_value(Value::tuple([Value::false_val(), Value::true_val()])); + let [o1, o2] = builder + .add_dataflow_op(UnpackTuple::new(vec![bool_t(); 2].into()), [v]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o1_r: Value = results.try_read_wire_concrete(o1).unwrap(); + assert_eq!(o1_r, Value::false_val()); + let o2_r: Value = results.try_read_wire_concrete(o2).unwrap(); + assert_eq!(o2_r, Value::true_val()); +} + +#[test] +fn test_tail_loop_never_iterates() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let r_v = Value::unit_sum(3, 6).unwrap(); + let r_w = builder.add_load_value(r_v.clone()); + let tag = Tag::new( + TailLoop::BREAK_TAG, + vec![type_row![], r_v.get_type().into()], + ); + let tagged = builder.add_dataflow_op(tag, [r_w]).unwrap(); + + let tlb = builder + .tail_loop_builder([], [], vec![r_v.get_type()].into()) + .unwrap(); + let tail_loop = tlb.finish_with_outputs(tagged.out_wire(0), []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o_r: Value = results.try_read_wire_concrete(tl_o).unwrap(); + assert_eq!(o_r, r_v); + assert_eq!( + Some(TailLoopTermination::NeverContinues), + results.tail_loop_terminates(tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_always_iterates() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let r_w = builder.add_load_value( + Value::sum( + TailLoop::CONTINUE_TAG, + [], + SumType::new([type_row![], bool_t().into()]), + ) + .unwrap(), + ); + let true_w = builder.add_load_value(Value::true_val()); + + let tlb = builder + .tail_loop_builder([], [(bool_t(), true_w)], vec![bool_t()].into()) + .unwrap(); + + // r_w has tag 0, so we always continue; + // we put true in our "other_output", but we should not propagate this to + // output because r_w never supports 1. + let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap(); + + let [tl_o1, tl_o2] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o_r1 = results.read_out_wire(tl_o1).unwrap(); + assert_eq!(o_r1, PartialValue::bottom()); + let o_r2 = results.read_out_wire(tl_o2).unwrap(); + assert_eq!(o_r2, PartialValue::bottom()); + assert_eq!( + Some(TailLoopTermination::NeverBreaks), + results.tail_loop_terminates(tail_loop.node()) + ); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); +} + +#[test] +fn test_tail_loop_two_iters() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let tlb = builder + .tail_loop_builder_exts( + [], + [(bool_t(), false_w), (bool_t(), true_w)], + type_row![], + ExtensionSet::new(), + ) + .unwrap(); + assert_eq!( + tlb.loop_signature().unwrap().signature(), + Signature::new_endo(vec![bool_t(); 2]) + ); + let [in_w1, in_w2] = tlb.input_wires_arr(); + let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let [o_w1, o_w2] = tail_loop.outputs_arr(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o_r1 = results.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, pv_true_or_false()); + let o_r2 = results.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, pv_true_or_false()); + assert_eq!( + Some(TailLoopTermination::BreaksAndContinues), + results.tail_loop_terminates(tail_loop.node()) + ); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); +} + +#[test] +fn test_tail_loop_containing_conditional() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let control_variants = vec![vec![bool_t(); 2].into(); 2]; + let control_t = Type::new_sum(control_variants.clone()); + let body_out_variants = vec![TypeRow::from(control_t.clone()), vec![bool_t(); 2].into()]; + + let init = builder.add_load_value( + Value::sum( + 0, + [Value::false_val(), Value::true_val()], + SumType::new(control_variants.clone()), + ) + .unwrap(), + ); + + let mut tlb = builder + .tail_loop_builder([(control_t, init)], [], vec![bool_t(); 2].into()) + .unwrap(); + let tl = tlb.loop_signature().unwrap().clone(); + let [in_w] = tlb.input_wires_arr(); + + // Branch on in_wire, so first iter 0(false, true)... + let mut cond = tlb + .conditional_builder( + (control_variants.clone(), in_w), + [], + Type::new_sum(body_out_variants.clone()).into(), + ) + .unwrap(); + let mut case0_b = cond.case_builder(0).unwrap(); + let [a, b] = case0_b.input_wires_arr(); + // Builds value for next iter as 1(true, false) by flipping arguments + let [next_input] = case0_b + .add_dataflow_op(Tag::new(1, control_variants), [b, a]) + .unwrap() + .outputs_arr(); + let cont = case0_b.make_continue(tl.clone(), [next_input]).unwrap(); + case0_b.finish_with_outputs([cont]).unwrap(); + // Second iter 1(true, false) => exit with (true, false) + let mut case1_b = cond.case_builder(1).unwrap(); + let loop_res = case1_b.make_break(tl, case1_b.input_wires()).unwrap(); + case1_b.finish_with_outputs([loop_res]).unwrap(); + let [r] = cond.finish_sub_container().unwrap().outputs_arr(); + + let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let [o_w1, o_w2] = tail_loop.outputs_arr(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o_r1 = results.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, pv_true()); + let o_r2 = results.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, pv_false()); + assert_eq!( + Some(TailLoopTermination::BreaksAndContinues), + results.tail_loop_terminates(tail_loop.node()) + ); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); +} + +#[test] +fn test_conditional() { + let variants = vec![type_row![], type_row![], bool_t().into()]; + let cond_t = Type::new_sum(variants.clone()); + let mut builder = DFGBuilder::new(Signature::new(cond_t, type_row![])).unwrap(); + let [arg_w] = builder.input_wires_arr(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let mut cond_builder = builder + .conditional_builder( + (variants, arg_w), + [(bool_t(), true_w)], + vec![bool_t(); 2].into(), + ) + .unwrap(); + // will be unreachable + let case1_b = cond_builder.case_builder(0).unwrap(); + let case1 = case1_b.finish_with_outputs([false_w, false_w]).unwrap(); + + let case2_b = cond_builder.case_builder(1).unwrap(); + let [c2a] = case2_b.input_wires_arr(); + let case2 = case2_b.finish_with_outputs([false_w, c2a]).unwrap(); + + let case3_b = cond_builder.case_builder(2).unwrap(); + let [c3_1, _c3_2] = case3_b.input_wires_arr(); + let case3 = case3_b.finish_with_outputs([c3_1, false_w]).unwrap(); + + let cond = cond_builder.finish_sub_container().unwrap(); + + let [cond_o1, cond_o2] = cond.outputs_arr(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let arg_pv = PartialValue::new_variant(1, []).join(PartialValue::new_variant( + 2, + [PartialValue::new_variant(0, [])], + )); + let results = Machine::new(&hugr).run(TestContext, [(0.into(), arg_pv)]); + + let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); + assert_eq!(cond_r1, Value::false_val()); + assert!(results + .try_read_wire_concrete::(cond_o2) + .is_err()); + + assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only + assert_eq!(results.case_reachable(case2.node()), Some(true)); + assert_eq!(results.case_reachable(case3.node()), Some(true)); + assert_eq!(results.case_reachable(cond.node()), None); +} + +// A Hugr being a function on bools: (x, y) => (x XOR y, x AND y) +#[fixture] +fn xor_and_cfg() -> Hugr { + // Entry branch on first arg, passes arguments on unchanged + // /T F\ + // A --T-> B A(x=true, y) branch on second arg, passing (first arg == true, false) + // \F / B(w,v) => X(v,w) + // > X < + // Inputs received: + // Entry A B X + // F,F - F,F F,F + // F,T - F,T T,F + // T,F T,F - T,F + // T,T T,T T,F F,T + let mut builder = + CFGBuilder::new(Signature::new(vec![bool_t(); 2], vec![bool_t(); 2])).unwrap(); + + // entry (x, y) => (if x then A else B)(x=true, y) + let entry = builder + .entry_builder(vec![type_row![]; 2], vec![bool_t(); 2].into()) + .unwrap(); + let [in_x, in_y] = entry.input_wires_arr(); + let entry = entry.finish_with_outputs(in_x, [in_x, in_y]).unwrap(); + + // A(x==true, y) => (if y then B else X)(x, false) + let mut a = builder + .block_builder( + vec![bool_t(); 2].into(), + vec![type_row![]; 2], + vec![bool_t(); 2].into(), + ) + .unwrap(); + let [in_x, in_y] = a.input_wires_arr(); + let false_w1 = a.add_load_value(Value::false_val()); + let a = a.finish_with_outputs(in_y, [in_x, false_w1]).unwrap(); + + // B(w, v) => X(v, w) + let mut b = builder + .block_builder( + vec![bool_t(); 2].into(), + [type_row![]], + vec![bool_t(); 2].into(), + ) + .unwrap(); + let [in_w, in_v] = b.input_wires_arr(); + let [control] = b + .add_dataflow_op(Tag::new(0, vec![type_row![]]), []) + .unwrap() + .outputs_arr(); + let b = b.finish_with_outputs(control, [in_v, in_w]).unwrap(); + + let x = builder.exit_block(); + + let [fals, tru]: [usize; 2] = [0, 1]; + builder.branch(&entry, tru, &a).unwrap(); // if true + builder.branch(&entry, fals, &b).unwrap(); // if false + builder.branch(&a, tru, &b).unwrap(); // if true + builder.branch(&a, fals, &x).unwrap(); // if false + builder.branch(&b, 0, &x).unwrap(); + builder.finish_hugr(&EMPTY_REG).unwrap() +} + +#[rstest] +#[case(pv_true(), pv_true(), pv_false(), pv_true())] +#[case(pv_true(), pv_false(), pv_true(), pv_false())] +#[case(pv_true(), pv_true_or_false(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true(), PartialValue::Top, pv_true_or_false(), pv_true_or_false())] +#[case(pv_false(), pv_true(), pv_true(), pv_false())] +#[case(pv_false(), pv_false(), pv_false(), pv_false())] +#[case(pv_false(), pv_true_or_false(), pv_true_or_false(), pv_false())] +#[case(pv_false(), PartialValue::Top, PartialValue::Top, pv_false())] // if !inp0 then out0=inp1 +#[case(pv_true_or_false(), pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false(), pv_true_or_false())] +#[case(PartialValue::Top, pv_true(), pv_true_or_false(), PartialValue::Top)] +#[case(PartialValue::Top, pv_false(), PartialValue::Top, PartialValue::Top)] +fn test_cfg( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] out0: PartialValue, + #[case] out1: PartialValue, + xor_and_cfg: Hugr, +) { + let root = xor_and_cfg.root(); + let results = Machine::new(&xor_and_cfg).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]); + + assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0); + assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1); +} + +#[rstest] +#[case(pv_true(), pv_true(), pv_true())] +#[case(pv_false(), pv_false(), pv_false())] +#[case(pv_true(), pv_false(), pv_true_or_false())] // Two calls alias +fn test_call( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] out: PartialValue, +) { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap(); + let func_bldr = builder + .define_function("id", Signature::new_endo(bool_t())) + .unwrap(); + let [v] = func_bldr.input_wires_arr(); + let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); + let [a, b] = builder.input_wires_arr(); + let [a2] = builder + .call(func_defn.handle(), &[], [a], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let [b2] = builder + .call(func_defn.handle(), &[], [b], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let hugr = builder + .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) + .unwrap(); + + let results = Machine::new(&hugr).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]); + + let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); + // The two calls alias so both results will be the same: + assert_eq!(res0, out); + assert_eq!(res1, out); +} + +#[test] +fn test_region() { + let mut builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap(); + let [in_w] = builder.input_wires_arr(); + let cst_w = builder.add_load_const(Value::false_val()); + let nested = builder + .dfg_builder(Signature::new_endo(vec![bool_t(); 2]), [in_w, cst_w]) + .unwrap(); + let nested_ins = nested.input_wires(); + let nested = nested.finish_with_outputs(nested_ins).unwrap(); + let hugr = builder + .finish_prelude_hugr_with_outputs(nested.outputs()) + .unwrap(); + let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); + let whole_hugr_results = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(pv_false()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 1)), + Some(pv_false()) + ); + + let subview = DescendantsGraph::::try_new(&hugr, nested.node()).unwrap(); + // Do not provide a value on the second input (constant false in the whole hugr, above) + let sub_hugr_results = Machine::new(subview).run(TestContext, [(0.into(), pv_true())]); + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(PartialValue::Top) + ); + for w in [0, 1] { + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(hugr.root(), w)), + None + ); + } +} + +#[test] +fn test_module() { + let mut modb = ModuleBuilder::new(); + let leaf_fn = modb + .define_function("leaf", Signature::new_endo(vec![bool_t(); 2])) + .unwrap(); + let outs = leaf_fn.input_wires(); + let leaf_fn = leaf_fn.finish_with_outputs(outs).unwrap(); + + let mut f2 = modb + .define_function("f2", Signature::new(bool_t(), vec![bool_t(); 2])) + .unwrap(); + let [inp] = f2.input_wires_arr(); + let cst_true = f2.add_load_value(Value::true_val()); + let f2_call = f2 + .call(leaf_fn.handle(), &[], [inp, cst_true], &EMPTY_REG) + .unwrap(); + let f2 = f2.finish_with_outputs(f2_call.outputs()).unwrap(); + + let mut main = modb + .define_function("main", Signature::new(bool_t(), vec![bool_t(); 2])) + .unwrap(); + let [inp] = main.input_wires_arr(); + let cst_false = main.add_load_value(Value::false_val()); + let main_call = main + .call(leaf_fn.handle(), &[], [inp, cst_false], &EMPTY_REG) + .unwrap(); + main.finish_with_outputs(main_call.outputs()).unwrap(); + let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); + let [f2_inp, _] = hugr.get_io(f2.node()).unwrap(); + + let results_just_main = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]); + assert_eq!( + results_just_main.read_out_wire(Wire::new(f2_inp, 0)), + Some(PartialValue::Bottom) + ); + for call in [f2_call, main_call] { + // The first output of the Call comes from `main` because no value was fed in from f2 + assert_eq!( + results_just_main.read_out_wire(Wire::new(call.node(), 0)), + Some(pv_true()) + ); + // (Without reachability) the second output of the Call is the join of the two constant inputs from the two calls + assert_eq!( + results_just_main.read_out_wire(Wire::new(call.node(), 1)), + Some(pv_true_or_false()) + ); + } + + let results_two_calls = { + let mut m = Machine::new(&hugr); + m.prepopulate_df_inputs(f2.node(), [(0.into(), pv_true())]); + m.run(TestContext, [(0.into(), pv_false())]) + }; + + for call in [f2_call, main_call] { + assert_eq!( + results_two_calls.read_out_wire(Wire::new(call.node(), 0)), + Some(pv_true_or_false()) + ); + assert_eq!( + results_two_calls.read_out_wire(Wire::new(call.node(), 1)), + Some(pv_true_or_false()) + ); + } +} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs new file mode 100644 index 000000000..50cf10318 --- /dev/null +++ b/hugr-passes/src/dataflow/value_row.rs @@ -0,0 +1,103 @@ +// Wrap a (known-length) row of values into a lattice. + +use std::{ + cmp::Ordering, + ops::{Index, IndexMut}, +}; + +use ascent::{lattice::BoundedLattice, Lattice}; +use itertools::zip_eq; + +use super::{AbstractValue, PartialValue}; + +#[derive(PartialEq, Clone, Debug, Eq, Hash)] +pub(super) struct ValueRow(Vec>); + +impl ValueRow { + pub fn new(len: usize) -> Self { + Self(vec![PartialValue::bottom(); len]) + } + + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + *self.0.get_mut(idx).unwrap() = v; + self + } + + pub fn singleton(v: PartialValue) -> Self { + Self(vec![v]) + } + + /// The first value in this ValueRow must be a sum; + /// returns a new ValueRow given by unpacking the elements of the specified variant of said first value, + /// then appending the rest of the values in this row. + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option>> { + let vals = self[0].variant_values(variant, len)?; + Some(vals.into_iter().chain(self.0[1..].to_owned())) + } +} + +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PartialValue; + + type IntoIter = > as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec>: Index, +{ + type Output = > as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +impl IndexMut for ValueRow +where + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 9042850d4..a2208430d 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod dataflow; pub mod force_order; mod half_node; pub mod lower;