Skip to content

Commit

Permalink
inference tests
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Jun 14, 2024
1 parent cd45bc9 commit f3a05c4
Showing 1 changed file with 161 additions and 8 deletions.
169 changes: 161 additions & 8 deletions hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,6 @@ impl Hugr {
infer(self, self.root(), remove)?;
Ok(())
}

// Note: tests
// * all combinations of (remove or not, TO_BE_INFERRED present or absent, success(inferred-set) or failure (possible only if no TO_BE_INFERRED) )
// * parent - child - grandchild tests:
// X - Y + INFER - X (ok with remove, but fails w/out remove)
// X - Y + INFER - Y or X - INFER - Y (mid fails against parent with just Y, regardless of remove)
// X - INFER - X (ok with-or-without remove)
}

/// Internal API for HUGRs, not intended for use by users.
Expand Down Expand Up @@ -285,7 +278,13 @@ pub enum HugrError {

#[cfg(test)]
mod test {
use super::{Hugr, HugrView};
#[cfg(feature = "extension_inference")]
use super::ValidationError;
use super::{ExtensionError, Hugr, HugrMut, HugrView, Node};
use crate::extension::{ExtensionId, ExtensionSet, EMPTY_REG, TO_BE_INFERRED};
use crate::types::{FunctionType, Type};
use crate::{const_extension_ids, ops, type_row};
use rstest::rstest;

#[test]
fn impls_send_and_sync() {
Expand All @@ -304,4 +303,158 @@ mod test {
let hugr = simple_dfg_hugr();
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

const_extension_ids! {
const XA: ExtensionId = "EXT_A";
const XB: ExtensionId = "EXT_B";
}

#[rstest]
#[case([], XA.into())]
#[case([XA], XA.into())]
#[case([XB], ExtensionSet::from_iter([XA, XB]))]

fn infer_single_delta(
#[case] parent: impl IntoIterator<Item = ExtensionId>,
#[values(true, false)] remove: bool, // makes no difference when inferring
#[case] result: ExtensionSet,
) {
let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into());
let (mut h, _) = build_ext_cfg(parent);
h.infer_extensions(remove).unwrap();
assert_eq!(h, build_ext_cfg(result).0);
}

#[test]
fn infer_removes_from_delta() {
let parent = ExtensionSet::from_iter([XA, XB]);
let mut h = build_ext_cfg(parent.clone()).0;
let backup = h.clone();
h.infer_extensions(false).unwrap();
assert_eq!(h, backup); // did nothing
h.infer_extensions(true).unwrap();
assert_eq!(h, build_ext_cfg(XA.into()).0);
}

#[test]
fn infer_bad_remove() {
let (mut h, mid) = build_ext_cfg(XB.into());
let backup = h.clone();
h.infer_extensions(false).unwrap();
assert_eq!(h, backup); // did nothing
let val_res = h.validate(&EMPTY_REG);
let expected_err = ExtensionError {
parent: h.root(),
parent_extensions: XB.into(),
child: mid,
child_extensions: XA.into(),
};
#[cfg(feature = "extension_inference")]
assert_eq!(
val_res,
Err(ValidationError::ExtensionError(expected_err.clone()))
);
#[cfg(not(feature = "extension_inference"))]
assert!(val_res.is_ok());

let inf_res = h.infer_extensions(true);
assert_eq!(inf_res, Err(expected_err));
}

fn build_ext_cfg(parent: ExtensionSet) -> (Hugr, Node) {
let ty = Type::new_function(FunctionType::new_endo(type_row![]));
let mut h = Hugr::new(ops::DFG {
signature: FunctionType::new_endo(ty.clone()).with_extension_delta(parent.clone()),
});
let root = h.root();
let mid = add_inliftout(&mut h, root, ty);
(h, mid)
}

fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node {
let inp = h.add_node_with_parent(
p,
ops::Input {
types: ty.clone().into(),
},
);
let out = h.add_node_with_parent(
p,
ops::Output {
types: ty.clone().into(),
},
);
let mid = h.add_node_with_parent(
p,
ops::Lift {
type_row: ty.into(),
new_extension: XA,
},
);
h.connect(inp, 0, mid, 0);
h.connect(mid, 0, out, 0);
mid
}

#[rstest]
#[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])]
#[case([XA], [TO_BE_INFERRED], true, [XA])]
#[case([XB], [TO_BE_INFERRED], false, [XA])]
#[case([XB], [XA, TO_BE_INFERRED], false, [XA])]
#[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])]
#[case([XA, XB], [TO_BE_INFERRED], true, [XA])]
#[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])]
// This one just tests removal:
#[case([XA], [XA, XB], true, [XA])]
// TODO: Consider adding a separate expected-grandparent so we can have something different?
fn infer_three_generations(
#[case] grandparent: impl IntoIterator<Item = ExtensionId>,
#[case] parent: impl IntoIterator<Item = ExtensionId>,
#[case] success: bool,
#[case] result: impl IntoIterator<Item = ExtensionId>,
) {
let ty = Type::new_function(FunctionType::new_endo(type_row![]));
let grandparent = ExtensionSet::from_iter(grandparent);
let result = ExtensionSet::from_iter(result);
let root_ty = ops::Conditional {
sum_rows: vec![type_row![]],
other_inputs: ty.clone().into(),
outputs: ty.clone().into(),
extension_delta: grandparent.clone(),
};
let mut h = Hugr::new(root_ty.clone());
let p = h.add_node_with_parent(
h.root(),
ops::Case {
signature: FunctionType::new_endo(ty.clone())
.with_extension_delta(ExtensionSet::from_iter(parent)),
},
);
add_inliftout(&mut h, p, ty.clone());
assert!(h.validate_extensions().is_err());
let inf_res = h.infer_extensions(true);
if success {
assert!(inf_res.is_ok());
let expected_p = ops::Case {
signature: FunctionType::new_endo(ty).with_extension_delta(result.clone()),
};
assert!(h.get_optype(p) == &expected_p.into());
let expected_gp = ops::Conditional {
extension_delta: result,
..root_ty
};
assert!(h.root_type() == &expected_gp.into())
// rest should be unchanged...
} else {
assert_eq!(
inf_res,
Err(ExtensionError {
parent: h.root(),
parent_extensions: grandparent,
child: p,
child_extensions: result
})
);
}
}
}

0 comments on commit f3a05c4

Please sign in to comment.