diff --git a/src/unify/backing_vec.rs b/src/unify/backing_vec.rs index 7c1eb2c..bac9556 100644 --- a/src/unify/backing_vec.rs +++ b/src/unify/backing_vec.rs @@ -6,7 +6,7 @@ use std::ops::{self, Range}; use undo_log::{Rollback, Snapshots, UndoLogs, VecLog}; -use super::{UnifyKey, UnifyValue, VarValue}; +use super::{ExtraTraversalData, NoExtraTraversalData, UnifyKey, UnifyValue, VarValue}; #[allow(dead_code)] // rustc BUG #[allow(type_alias_bounds)] @@ -15,9 +15,12 @@ type Key = ::Key; /// Largely internal trait implemented by the unification table /// backing store types. The most common such type is `InPlace`, /// which indicates a standard, mutable unification table. -pub trait UnificationStoreBase: ops::Index>> { +pub trait UnificationStoreBase: + ops::Index, Self::ExtraTraversalData>> +{ type Key: UnifyKey; type Value: UnifyValue; + type ExtraTraversalData: ExtraTraversalData; fn len(&self) -> usize; @@ -27,15 +30,18 @@ pub trait UnificationStoreBase: ops::Index>> } pub trait UnificationStoreMut: UnificationStoreBase { - fn reset_unifications(&mut self, value: impl FnMut(u32) -> VarValue); + fn reset_unifications( + &mut self, + value: impl FnMut(u32) -> VarValue, + ); - fn push(&mut self, value: VarValue); + fn push(&mut self, value: VarValue); fn reserve(&mut self, num_new_values: usize); fn update(&mut self, index: usize, op: F) where - F: FnOnce(&mut VarValue); + F: FnOnce(&mut VarValue); } pub trait UnificationStore: UnificationStoreMut { @@ -55,14 +61,21 @@ pub trait UnificationStore: UnificationStoreMut { #[derive(Clone, Debug)] pub struct InPlace< K: UnifyKey, - V: sv::VecLike> = Vec>, - L = VecLog>>, + TD: ExtraTraversalData = NoExtraTraversalData, + V: sv::VecLike> = Vec>, + L = VecLog>>, > { - pub(crate) values: sv::SnapshotVec, V, L>, + pub(crate) values: sv::SnapshotVec, V, L>, } // HACK(eddyb) manual impl avoids `Default` bound on `K`. -impl> + Default, L: Default> Default for InPlace { +impl< + K: UnifyKey, + TD: ExtraTraversalData, + V: sv::VecLike> + Default, + L: Default, + > Default for InPlace +{ fn default() -> Self { InPlace { values: sv::SnapshotVec::new(), @@ -70,32 +83,38 @@ impl> + Default, L: Default> Default for } } -impl UnificationStoreBase for InPlace +impl UnificationStoreBase for InPlace where K: UnifyKey, - V: sv::VecLike>, + V: sv::VecLike>, + TD: ExtraTraversalData, { type Key = K; type Value = K::Value; + type ExtraTraversalData = TD; fn len(&self) -> usize { self.values.len() } } -impl UnificationStoreMut for InPlace +impl UnificationStoreMut for InPlace where K: UnifyKey, - V: sv::VecLike>, - L: UndoLogs>>, + TD: ExtraTraversalData, + V: sv::VecLike>, + L: UndoLogs>>, { #[inline] - fn reset_unifications(&mut self, mut value: impl FnMut(u32) -> VarValue) { + fn reset_unifications( + &mut self, + mut value: impl FnMut(u32) -> VarValue, + ) { self.values.set_all(|i| value(i as u32)); } #[inline] - fn push(&mut self, value: VarValue) { + fn push(&mut self, value: VarValue) { self.values.push(value); } @@ -107,17 +126,18 @@ where #[inline] fn update(&mut self, index: usize, op: F) where - F: FnOnce(&mut VarValue), + F: FnOnce(&mut VarValue), { self.values.update(index, op) } } -impl UnificationStore for InPlace +impl UnificationStore for InPlace where K: UnifyKey, - V: sv::VecLike>, - L: Snapshots>>, + TD: ExtraTraversalData, + V: sv::VecLike>, + L: Snapshots>>, { type Snapshot = sv::Snapshot; @@ -142,43 +162,46 @@ where } } -impl ops::Index for InPlace +impl ops::Index for InPlace where - V: sv::VecLike>, + V: sv::VecLike>, K: UnifyKey, + TD: ExtraTraversalData, { - type Output = VarValue; - fn index(&self, index: usize) -> &VarValue { + type Output = VarValue; + fn index(&self, index: usize) -> &VarValue { &self.values[index] } } #[doc(hidden)] #[derive(Copy, Clone, Debug)] -pub struct Delegate(PhantomData); +pub struct Delegate(PhantomData<(K, TD)>); -impl sv::SnapshotVecDelegate for Delegate { - type Value = VarValue; +impl> sv::SnapshotVecDelegate for Delegate { + type Value = VarValue; type Undo = (); - fn reverse(_: &mut Vec>, _: ()) {} + fn reverse(_: &mut Vec>, _: ()) {} } -impl Rollback>> for super::UnificationTableStorage { - fn reverse(&mut self, undo: sv::UndoLog>) { +impl> Rollback>> + for super::UnificationTableStorage +{ + fn reverse(&mut self, undo: sv::UndoLog>) { self.values.values.reverse(undo); } } #[cfg(feature = "persistent")] #[derive(Clone, Debug)] -pub struct Persistent { - values: DVec>, +pub struct Persistent> { + values: DVec>, } // HACK(eddyb) manual impl avoids `Default` bound on `K`. #[cfg(feature = "persistent")] -impl Default for Persistent { +impl> Default for Persistent { fn default() -> Self { Persistent { values: DVec::new(), @@ -187,9 +210,10 @@ impl Default for Persistent { } #[cfg(feature = "persistent")] -impl UnificationStoreBase for Persistent { +impl> UnificationStoreBase for Persistent { type Key = K; type Value = K::Value; + type ExtraTraversalData = TD; fn len(&self) -> usize { self.values.len() @@ -197,9 +221,12 @@ impl UnificationStoreBase for Persistent { } #[cfg(feature = "persistent")] -impl UnificationStoreMut for Persistent { +impl> UnificationStoreMut for Persistent { #[inline] - fn reset_unifications(&mut self, mut value: impl FnMut(u32) -> VarValue) { + fn reset_unifications( + &mut self, + mut value: impl FnMut(u32) -> VarValue, + ) { // Without extending dogged, there isn't obviously a more // efficient way to do this. But it's pretty dumb. Maybe // dogged needs a `map`. @@ -209,7 +236,7 @@ impl UnificationStoreMut for Persistent { } #[inline] - fn push(&mut self, value: VarValue) { + fn push(&mut self, value: VarValue) { self.values.push(value); } @@ -221,7 +248,7 @@ impl UnificationStoreMut for Persistent { #[inline] fn update(&mut self, index: usize, op: F) where - F: FnOnce(&mut VarValue), + F: FnOnce(&mut VarValue), { let p = &mut self.values[index]; op(p); @@ -229,7 +256,7 @@ impl UnificationStoreMut for Persistent { } #[cfg(feature = "persistent")] -impl UnificationStore for Persistent { +impl> UnificationStore for Persistent { type Snapshot = Self; #[inline] @@ -252,12 +279,13 @@ impl UnificationStore for Persistent { } #[cfg(feature = "persistent")] -impl ops::Index for Persistent +impl ops::Index for Persistent where K: UnifyKey, + TD: ExtraTraversalData, { - type Output = VarValue; - fn index(&self, index: usize) -> &VarValue { + type Output = VarValue; + fn index(&self, index: usize) -> &VarValue { &self.values[index] } } diff --git a/src/unify/mod.rs b/src/unify/mod.rs index a26d699..34d43e8 100644 --- a/src/unify/mod.rs +++ b/src/unify/mod.rs @@ -160,10 +160,90 @@ pub struct NoError { /// time of the algorithm under control. For more information, see /// . #[derive(PartialEq, Clone, Debug)] -pub struct VarValue { +pub struct VarValue> { parent: K, // if equal to self, this is a root value: K::Value, // value assigned (only relevant to root) rank: u32, // max depth (only relevant to root) + traversing_data: TD::Data, +} + +// Debug and Clone on these are just here because derives will generate +// bounds on these types andhaving these enable us to still not have to +// write these impls manually. + +/// Use this as [`ExtraTraversalData`] if you don't need to traverse connected +/// components directly on the [`UnificationTable`] +#[derive(Debug, Clone)] +pub struct NoExtraTraversalData; +/// Use this as [`ExtraTraversalData`] if you need to traverse connected +/// components directly on the [`UnificationTable`] +/// +/// This makes the [`UnificationTable::unioned_keys`] function available +#[derive(Debug, Clone)] +pub struct ConnectedComponentTraversal; +/// Choose an implementor of this trait to determine whether the +/// [`UnificationTable`] will store the data necessary to do connected +/// component traversal. +/// +/// This trait is largely internal and is not meant to be implemented outside +/// of this crate. +/// +/// It provides necessary functions to change the behavior of +/// the [`UnificationTable`] where it should differ between the two +/// implementations. +pub trait ExtraTraversalData: Sized + Clone + Debug { + type Data: Clone + Debug; + fn init_from_key(k: K) -> Self::Data; + fn redirect(old_root_value: &mut VarValue, sibling: K); + fn make_root(new_root_value: &mut VarValue, child: K); + fn child(var_value: &VarValue, self_key: K) -> Option; +} +impl ExtraTraversalData for NoExtraTraversalData { + type Data = (); + #[inline(always)] + fn init_from_key(_k: K) -> Self::Data {} + #[inline(always)] + fn redirect(_old_root_value: &mut VarValue, _sibling: K) {} + #[inline(always)] + fn make_root(_new_root_value: &mut VarValue, _child: K) {} + #[inline(always)] + fn child(_var_value: &VarValue, _self_key: K) -> Option { + None + } +} +impl ExtraTraversalData for ConnectedComponentTraversal { + type Data = CongruenceClosureTraversalData; + #[inline(always)] + fn init_from_key(k: K) -> Self::Data { + CongruenceClosureTraversalData { + child: k, + sibling: k, + } + } + #[inline(always)] + fn redirect(old_root_value: &mut VarValue, sibling: K) { + // Since this used to be a root, we should have + // var_value.parent = var_value.sibling = the key of this value + debug_assert_eq!( + old_root_value.parent, + old_root_value.traversing_data.sibling + ); + old_root_value.traversing_data.sibling = sibling; + } + #[inline(always)] + fn make_root(new_root_value: &mut VarValue, child: K) { + new_root_value.traversing_data.child = child; + } + #[inline(always)] + fn child(var_value: &VarValue, self_key: K) -> Option { + var_value.child(self_key) + } +} +#[doc(hidden)] +#[derive(PartialEq, Clone, Debug)] +pub struct CongruenceClosureTraversalData { + child: K, + sibling: K, } /// Table of unification keys and their values. You must define a key type K @@ -185,21 +265,22 @@ pub struct UnificationTable { values: S, } -pub type UnificationStorage = Vec>; -pub type UnificationTableStorage = UnificationTable, ()>>; +pub type UnificationStorage = Vec>; +pub type UnificationTableStorage = + UnificationTable, ()>>; /// A unification table that uses an "in-place" vector. -#[allow(type_alias_bounds)] pub type InPlaceUnificationTable< - K: UnifyKey, - V: sv::VecLike> = Vec>, - L = VecLog>>, -> = UnificationTable>; + K, + TD = NoExtraTraversalData, + V = UnificationStorage, + L = VecLog>>, +> = UnificationTable>; /// A unification table that uses a "persistent" vector. #[cfg(feature = "persistent")] -#[allow(type_alias_bounds)] -pub type PersistentUnificationTable = UnificationTable>; +pub type PersistentUnificationTable = + UnificationTable>; /// At any time, users may snapshot a unification table. The changes /// made during the snapshot may either be *committed* or *rolled back*. @@ -209,24 +290,25 @@ pub struct Snapshot { snapshot: S::Snapshot, } -impl VarValue { - fn new_var(key: K, value: K::Value) -> VarValue { - VarValue::new(key, value, 0) - } - - fn new(parent: K, value: K::Value, rank: u32) -> VarValue { +impl> VarValue { + fn new(key: K, value: K::Value) -> Self { VarValue { - parent: parent, // this is a root + parent: key, // this is a root value: value, - rank: rank, + rank: 0, + traversing_data: TD::init_from_key(key), } } - fn redirect(&mut self, to: K) { + #[inline(always)] + fn redirect(&mut self, to: K, sibling: K) { + TD::redirect(self, sibling); self.parent = to; } - fn root(&mut self, rank: u32, value: K::Value) { + #[inline(always)] + fn make_root(&mut self, rank: u32, child: K, value: K::Value) { + TD::make_root(self, child); self.rank = rank; self.value = value; } @@ -243,18 +325,29 @@ impl VarValue { } } } -impl UnificationTableStorage +impl VarValue { + fn child(&self, self_key: K) -> Option { + self.if_not_self(self.traversing_data.child, self_key) + } + + fn sibling(&self, self_key: K) -> Option { + self.if_not_self(self.traversing_data.sibling, self_key) + } +} + +impl UnificationTableStorage where K: UnifyKey, + TD: ExtraTraversalData, { /// Creates a `UnificationTable` using an external `undo_log`, allowing mutating methods to be /// called if `L` does not implement `UndoLogs` pub fn with_log<'a, L>( &'a mut self, undo_log: L, - ) -> UnificationTable, L>> + ) -> UnificationTable, L>> where - L: UndoLogs>>, + L: UndoLogs>>, { UnificationTable { values: InPlace { @@ -319,7 +412,7 @@ impl UnificationTable { pub fn new_key(&mut self, value: S::Value) -> S::Key { let len = self.values.len(); let key: S::Key = UnifyKey::from_index(len as u32); - self.values.push(VarValue::new_var(key, value)); + self.values.push(VarValue::new(key, value)); debug!("{}: created new key: {:?}", S::tag(), key); key } @@ -337,13 +430,13 @@ impl UnificationTable { self.values.reset_unifications(|i| { let key = UnifyKey::from_index(i as u32); let value = value(key); - VarValue::new_var(key, value) + VarValue::new(key, value) }); } /// Obtains the current value for a particular key. /// Not for end-users; they can use `probe_value`. - fn value(&self, key: S::Key) -> &VarValue { + fn value(&self, key: S::Key) -> &VarValue { &self.values[key.index() as usize] } @@ -383,7 +476,7 @@ impl UnificationTable { fn update_value(&mut self, key: S::Key, op: OP) where - OP: FnOnce(&mut VarValue), + OP: FnOnce(&mut VarValue), { self.values.update(key.index() as usize, op); debug!("Updated variable {:?} to {:?}", key, self.value(key)); @@ -452,15 +545,128 @@ impl UnificationTable { new_root_key: S::Key, new_value: S::Value, ) { + let sibling = >::child( + self.value(new_root_key), + new_root_key, + ) + .unwrap_or(old_root_key); self.update_value(old_root_key, |old_root_value| { - old_root_value.redirect(new_root_key); + old_root_value.redirect(new_root_key, sibling); }); self.update_value(new_root_key, |new_root_value| { - new_root_value.root(new_rank, new_value); + new_root_value.make_root(new_rank, old_root_key, new_value); }); } } +impl> UnificationTable { + /// Returns an iterator over all keys unioned with `key`. + pub fn unioned_keys(&mut self, key: K1) -> UnionedKeys + where + K1: Into, + { + let key = key.into(); + let root_key = self.uninlined_get_root_key(key); + UnionedKeys { + table: self, + stack: vec![root_key], + } + } + + /// Clears all unifications that were connected to `key`. + /// They will all become individual elements again. + /// + /// This leaves other connected components intact. + /// + /// The values of each variable are given by the closure. + pub fn reset_unifications_partial( + &mut self, + key: impl Into, + mut value: impl FnMut(S::Key) -> S::Value, + ) { + let key = key.into(); + let unioned_keys: Vec<_> = self.unioned_keys(key).collect(); + for key in unioned_keys { + self.update_value(key, |var_value| { + *var_value = VarValue::new(key, value(key)); + }); + } + } +} + +/// Iterator over keys that have been unioned together. +/// You can only obtain this from an [`UnificationTable`] that uses +/// [`ConnectedComponentTraversal`] +/// +/// Returned by the `unioned_keys` method. +pub struct UnionedKeys<'a, S> +where + S: UnificationStoreMut + 'a, + S::Key: 'a, + S::Value: 'a, +{ + table: &'a mut UnificationTable, + stack: Vec, +} + +impl<'a, S> UnionedKeys<'a, S> +where + S: UnificationStoreMut + 'a, + S::Key: 'a, + S::Value: 'a, +{ + fn var_value(&self, key: S::Key) -> &VarValue { + self.table.value(key) + } +} + +impl<'a, S: 'a> Iterator for UnionedKeys<'a, S> +where + S: UnificationStoreMut + 'a, + S::Key: 'a, + S::Value: 'a, +{ + type Item = S::Key; + + fn next(&mut self) -> Option { + let key = match self.stack.last() { + Some(k) => *k, + None => { + return None; + } + }; + + let vv = self.var_value(key); + + match vv.child(key) { + Some(child_key) => { + self.stack.push(child_key); + } + None => { + // No child, push a sibling for the current node. If + // current node has no siblings, start popping + // ancestors until we find an aunt or uncle or + // something to push. Note that we have the invariant + // that for every node N that we reach by popping + // items off of the stack, we have already visited all + // children of N. + while let Some(ancestor_key) = self.stack.pop() { + let ancestor_vv = self.var_value(ancestor_key); + match ancestor_vv.sibling(ancestor_key) { + Some(sibling) => { + self.stack.push(sibling); + break; + } + None => {} + } + } + } + } + + Some(key) + } +} + /// //////////////////////////////////////////////////////////////////////// /// Public API diff --git a/src/unify/tests.rs b/src/unify/tests.rs index 5665aba..72087c4 100644 --- a/src/unify/tests.rs +++ b/src/unify/tests.rs @@ -17,9 +17,13 @@ extern crate test; #[cfg(feature = "bench")] use self::test::Bencher; use std::cmp; +use std::collections::HashSet; #[cfg(feature = "persistent")] use unify::Persistent; -use unify::{EqUnifyValue, InPlace, InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; +use unify::{ + ConnectedComponentTraversal, EqUnifyValue, InPlace, InPlaceUnificationTable, NoError, + NoExtraTraversalData, UnifyKey, UnifyValue, +}; use unify::{UnificationStore, UnificationTable}; #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] @@ -46,10 +50,13 @@ macro_rules! all_modes { $body } - test_body::>(); + test_body::>(); + test_body::>(); #[cfg(feature = "persistent")] - test_body::>(); + test_body::>(); + #[cfg(feature = "persistent")] + test_body::>(); }; } @@ -241,6 +248,31 @@ fn even_odd() { } } +#[test] +fn even_odd_iter() { + let mut ut: InPlaceUnificationTable = + UnificationTable::new(); + let mut keys = Vec::new(); + const MAX: usize = 1 << 10; + + for i in 0..MAX { + let key = ut.new_key(()); + keys.push(key); + + if i >= 2 { + ut.union(key, keys[i - 2]); + } + } + + let even_keys: HashSet = ut.unioned_keys(keys[22]).collect(); + + assert_eq!(even_keys.len(), MAX / 2); + + for key in even_keys { + assert!((key.0 & 1) == 0); + } +} + #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] struct IntKey(u32); diff --git a/tests/external_undo_log.rs b/tests/external_undo_log.rs index 2537826..9a217c3 100644 --- a/tests/external_undo_log.rs +++ b/tests/external_undo_log.rs @@ -5,7 +5,7 @@ extern crate ena; use ena::{ snapshot_vec as sv, undo_log::{Rollback, Snapshots, UndoLogs}, - unify::{self as ut, EqUnifyValue, UnifyKey}, + unify::{self as ut, EqUnifyValue, NoExtraTraversalData, UnifyKey}, }; #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] @@ -27,12 +27,12 @@ impl UnifyKey for IntKey { impl EqUnifyValue for IntKey {} enum UndoLog { - EqRelation(sv::UndoLog>), + EqRelation(sv::UndoLog>), Values(sv::UndoLog), } -impl From>> for UndoLog { - fn from(l: sv::UndoLog>) -> Self { +impl From>> for UndoLog { + fn from(l: sv::UndoLog>) -> Self { UndoLog::EqRelation(l) } } @@ -55,7 +55,6 @@ impl Rollback for TypeVariableStorage { #[derive(Default)] struct TypeVariableStorage { values: sv::SnapshotVecStorage, - eq_relations: ut::UnificationTableStorage, }