Skip to content

Commit

Permalink
feat: Lists and extension sets with splicing (#1657)
Browse files Browse the repository at this point in the history
This PR allows lists and extension sets in `hugr-model` to splice lists
and extension sets, e.g. `[0 xs ... 1 2 3]`. This is used to import and
export rows and extension sets with variables. Closes #1609.
  • Loading branch information
zrho authored Nov 27, 2024
1 parent 6a75f4c commit 344a7e4
Show file tree
Hide file tree
Showing 15 changed files with 367 additions and 249 deletions.
167 changes: 69 additions & 98 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,48 +509,24 @@ impl<'a> Context<'a> {
/// like for the other nodes since the ports are control flow ports.
pub fn export_block_signature(&mut self, block: &DataflowBlock) -> model::TermId {
let inputs = {
let mut inputs = BumpVec::with_capacity_in(block.inputs.len(), self.bump);
for input in block.inputs.iter() {
inputs.push(self.export_type(input));
}
let inputs = self.make_term(model::Term::List {
items: inputs.into_bump_slice(),
tail: None,
});
let inputs = self.export_type_row(&block.inputs);
let inputs = self.make_term(model::Term::Control { values: inputs });
self.make_term(model::Term::List {
items: self.bump.alloc_slice_copy(&[inputs]),
tail: None,
parts: self.bump.alloc_slice_copy(&[model::ListPart::Item(inputs)]),
})
};

let tail = {
let mut tail = BumpVec::with_capacity_in(block.other_outputs.len(), self.bump);
for other_output in block.other_outputs.iter() {
tail.push(self.export_type(other_output));
}
self.make_term(model::Term::List {
items: tail.into_bump_slice(),
tail: None,
})
};
let tail = self.export_type_row(&block.other_outputs);

let outputs = {
let mut outputs = BumpVec::with_capacity_in(block.sum_rows.len(), self.bump);
for sum_row in block.sum_rows.iter() {
let mut variant = BumpVec::with_capacity_in(sum_row.len(), self.bump);
for typ in sum_row.iter() {
variant.push(self.export_type(typ));
}
let variant = self.make_term(model::Term::List {
items: variant.into_bump_slice(),
tail: Some(tail),
});
outputs.push(self.make_term(model::Term::Control { values: variant }));
let variant = self.export_type_row_with_tail(sum_row, Some(tail));
let control = self.make_term(model::Term::Control { values: variant });
outputs.push(model::ListPart::Item(control));
}
self.make_term(model::Term::List {
items: outputs.into_bump_slice(),
tail: None,
parts: outputs.into_bump_slice(),
})
};

Expand Down Expand Up @@ -772,10 +748,12 @@ impl<'a> Context<'a> {
TypeArg::String { arg } => self.make_term(model::Term::Str(self.bump.alloc_str(arg))),
TypeArg::Sequence { elems } => {
// For now we assume that the sequence is meant to be a list.
let items = self
.bump
.alloc_slice_fill_iter(elems.iter().map(|elem| self.export_type_arg(elem)));
self.make_term(model::Term::List { items, tail: None })
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
.map(|elem| model::ListPart::Item(self.export_type_arg(elem))),
);
self.make_term(model::Term::List { parts })
}
TypeArg::Extensions { es } => self.export_ext_set(es),
TypeArg::Variable { v } => self.export_type_arg_var(v),
Expand All @@ -798,32 +776,53 @@ impl<'a> Context<'a> {
pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId {
match t {
SumType::Unit { size } => {
let items = self.bump.alloc_slice_fill_iter((0..*size).map(|_| {
self.make_term(model::Term::List {
items: &[],
tail: None,
})
let parts = self.bump.alloc_slice_fill_iter((0..*size).map(|_| {
model::ListPart::Item(self.make_term(model::Term::List { parts: &[] }))
}));
let list = model::Term::List { items, tail: None };
let variants = self.make_term(list);
let variants = self.make_term(model::Term::List { parts });
self.make_term(model::Term::Adt { variants })
}
SumType::General { rows } => {
let items = self
.bump
.alloc_slice_fill_iter(rows.iter().map(|row| self.export_type_row(row)));
let list = model::Term::List { items, tail: None };
let parts = self.bump.alloc_slice_fill_iter(
rows.iter()
.map(|row| model::ListPart::Item(self.export_type_row(row))),
);
let list = model::Term::List { parts };
let variants = { self.make_term(list) };
self.make_term(model::Term::Adt { variants })
}
}
}

pub fn export_type_row<RV: MaybeRV>(&mut self, t: &TypeRowBase<RV>) -> model::TermId {
let mut items = BumpVec::with_capacity_in(t.len(), self.bump);
items.extend(t.iter().map(|row| self.export_type(row)));
let items = items.into_bump_slice();
self.make_term(model::Term::List { items, tail: None })
#[inline]
pub fn export_type_row<RV: MaybeRV>(&mut self, row: &TypeRowBase<RV>) -> model::TermId {
self.export_type_row_with_tail(row, None)
}

pub fn export_type_row_with_tail<RV: MaybeRV>(
&mut self,
row: &TypeRowBase<RV>,
tail: Option<model::TermId>,
) -> model::TermId {
let mut parts = BumpVec::with_capacity_in(row.len() + tail.is_some() as usize, self.bump);

for t in row.iter() {
match t.as_type_enum() {
TypeEnum::RowVar(var) => {
parts.push(model::ListPart::Splice(self.export_row_var(var.as_rv())));
}
_ => {
parts.push(model::ListPart::Item(self.export_type(t)));
}
}
}

if let Some(tail) = tail {
parts.push(model::ListPart::Splice(tail));
}

let parts = parts.into_bump_slice();
self.make_term(model::Term::List { parts })
}

/// Exports a `TypeParam` to a term.
Expand Down Expand Up @@ -855,12 +854,12 @@ impl<'a> Context<'a> {
self.make_term(model::Term::ListType { item_type })
}
TypeParam::Tuple { params } => {
let items = self.bump.alloc_slice_fill_iter(
let parts = self.bump.alloc_slice_fill_iter(
params
.iter()
.map(|param| self.export_type_param(param, None)),
.map(|param| model::ListPart::Item(self.export_type_param(param, None))),
);
let types = self.make_term(model::Term::List { items, tail: None });
let types = self.make_term(model::Term::List { parts });
self.make_term(model::Term::ApplyFull {
global: model::GlobalRef::Named(TERM_PARAM_TUPLE),
args: self.bump.alloc_slice_copy(&[types]),
Expand All @@ -873,54 +872,26 @@ impl<'a> Context<'a> {
}
}

pub fn export_ext_set(&mut self, t: &ExtensionSet) -> model::TermId {
// Extension sets with variables are encoded using a hack: a variable in the
// extension set is represented by converting its index into a string.
// Until we have a better representation for extension sets, we therefore
// need to try and parse each extension as a number to determine if it is
// a variable or an extension.

// NOTE: This overprovisions the capacity since some of the entries of the row
// may be variables. Since we panic when there is more than one variable, this
// may at most waste one slot. That is way better than having to allocate
// a temporary vector.
//
// Also `ExtensionSet` has no way of reporting its size, so we have to count
// the elements by iterating over them...
let capacity = t.iter().count();
let mut extensions = BumpVec::with_capacity_in(capacity, self.bump);
let mut rest = None;

for ext in t.iter() {
if let Ok(index) = ext.parse::<usize>() {
// Extension sets in the model support at most one variable. This is a
// deliberate limitation so that extension sets behave like polymorphic rows.
// The type theory of such rows and how to apply them to model (co)effects
// is well understood.
//
// Extension sets in `hugr-core` at this point have no such restriction.
// However, it appears that so far we never actually use extension sets with
// multiple variables, except for extension sets that are generated through
// property testing.
if rest.is_some() {
// TODO: We won't need this anymore once we have a core representation
// that ensures that extension sets have at most one variable.
panic!("Extension set with multiple variables")
}
pub fn export_ext_set(&mut self, ext_set: &ExtensionSet) -> model::TermId {
let capacity = ext_set.iter().size_hint().0;
let mut parts = BumpVec::with_capacity_in(capacity, self.bump);

let node = self.local_scope.expect("local variable out of scope");
rest = Some(
self.module
.insert_term(model::Term::Var(model::LocalRef::Index(node, index as _))),
);
} else {
extensions.push(self.bump.alloc_str(ext) as &str);
for ext in ext_set.iter() {
// `ExtensionSet`s represent variables by extension names that parse to integers.
match ext.parse::<u16>() {
Ok(var) => {
let node = self.local_scope.expect("local variable out of scope");
let local_ref = model::LocalRef::Index(node, var);
let term = self.make_term(model::Term::Var(local_ref));
parts.push(model::ExtSetPart::Splice(term));
}
Err(_) => parts.push(model::ExtSetPart::Extension(self.bump.alloc_str(ext))),
}
}

let extensions = extensions.into_bump_slice();

self.make_term(model::Term::ExtSet { extensions, rest })
self.make_term(model::Term::ExtSet {
parts: parts.into_bump_slice(),
})
}

pub fn export_node_metadata(
Expand Down
Loading

0 comments on commit 344a7e4

Please sign in to comment.