diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 371b77df22a..7185cb5316e 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -18,7 +18,12 @@ def compare_states( - from_state: SyncState, to_state: SyncState, include_ignored: bool = False + from_state: SyncState, + to_state: SyncState, + include_ignored: bool = False, + include_same: bool = False, + filter_by_email: str | None = None, + filter_by_type: str | type | None = None, ) -> NodeDiff: # NodeDiff if ( @@ -42,11 +47,28 @@ def compare_states( high_state=high_state, direction=direction, include_ignored=include_ignored, + include_same=include_same, + filter_by_email=filter_by_email, + filter_by_type=filter_by_type, ) -def compare_clients(low_client: SyftClient, high_client: SyftClient) -> NodeDiff: - return compare_states(low_client.get_sync_state(), high_client.get_sync_state()) +def compare_clients( + from_client: SyftClient, + to_client: SyftClient, + include_ignored: bool = False, + include_same: bool = False, + filter_by_email: str | None = None, + filter_by_type: type | None = None, +) -> NodeDiff: + return compare_states( + from_client.get_sync_state(), + to_client.get_sync_state(), + include_ignored=include_ignored, + include_same=include_same, + filter_by_email=filter_by_email, + filter_by_type=filter_by_type, + ) def get_user_input_for_resolve() -> SyncDecision: diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 4c24487edb4..f40e86306c3 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -1054,10 +1054,10 @@ def from_batch(self, batch: ObjectDiffBatch) -> Any: if isinstance(user, UserView): return user.email return None - elif self == FilterProperty.BATCH_TYPE: - return batch.root_diff.obj_type + elif self == FilterProperty.TYPE: + return batch.root_diff.obj_type.__name__.lower() elif self == FilterProperty.STATUS: - return batch.status + return batch.status.lower() elif self == FilterProperty.IGNORED: return batch.is_ignored else: @@ -1069,7 +1069,7 @@ class NodeDiffFilter: """ Filter to apply to a NodeDiff object to determine if it should be included in a batch. - Tests for `property op value` , where + Checks for `property op value` , where property: FilterProperty - property to filter on value: Any - value to compare against op: callable[[Any, Any], bool] - comparison operator. Default is `operator.eq` @@ -1082,28 +1082,21 @@ class NodeDiffFilter: op: Callable[[Any, Any], bool] = operator.eq def __call__(self, batch: ObjectDiffBatch) -> bool: + filter_value = self.filter_value + if isinstance(filter_value, str): + filter_value = filter_value.lower() + try: p = self.filter_property.from_batch(batch) if self.op == operator.contains: # Contains check has reversed arg order: check if p in self.filter_value - return p in self.filter_value + return p in filter_value else: - return self.op(p, self.filter_value) + return self.op(p, filter_value) except Exception as e: + # By default, exclude the batch if there is an error logger.debug(f"Error filtering batch {batch} with {self}: {e}") - return True - - def __hash__(self) -> int: - return hash(self.filter_property) + hash(self.filter_value) + hash(self.op) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, NodeDiffFilter): return False - return ( - self.filter_property == other.filter_property - and self.filter_value == other.filter_value - and self.op == other.op - ) class NodeDiff(SyftObject): @@ -1160,6 +1153,8 @@ def from_sync_state( direction: SyncDirection, include_ignored: bool = False, include_same: bool = False, + filter_by_email: str | None = None, + filter_by_type: type | None = None, _include_node_status: bool = False, ) -> "NodeDiff": obj_uid_to_diff = {} @@ -1212,31 +1207,31 @@ def from_sync_state( previously_ignored_batches = low_state.ignored_batches NodeDiff.apply_previous_ignore_state(all_batches, previously_ignored_batches) - filters = [] - if not include_ignored: - filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne)) - if not include_same: - filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne)) - - batches = all_batches - for f in filters: - batches = [b for b in batches if f(b)] - - return cls( + res = cls( low_node_uid=low_state.node_uid, high_node_uid=high_state.node_uid, user_verify_key_low=low_state.syft_client_verify_key, user_verify_key_high=high_state.syft_client_verify_key, obj_uid_to_diff=obj_uid_to_diff, obj_dependencies=obj_dependencies, - batches=batches, + batches=all_batches, all_batches=all_batches, low_state=low_state, high_state=high_state, direction=direction, - filters=filters, + filters=[], + ) + + res._filter( + user_email=filter_by_email, + obj_type=filter_by_type, + include_ignored=include_ignored, + include_same=include_same, + inplace=True, ) + return res + @staticmethod def apply_previous_ignore_state( batches: list[ObjectDiffBatch], previously_ignored_batches: dict[UID, int] @@ -1414,65 +1409,65 @@ def hierarchies( def is_same(self) -> bool: return all(object_diff.status == "SAME" for object_diff in self.diffs) - def _apply_filters(self, filters: list[NodeDiffFilter]) -> Self: + def _apply_filters( + self, filters: list[NodeDiffFilter], inplace: bool = True + ) -> Self: """ Apply filters to the NodeDiff object and return a new NodeDiff object """ batches = self.all_batches for filter in filters: batches = [b for b in batches if filter(b)] - return NodeDiff( - low_node_uid=self.low_node_uid, - high_node_uid=self.high_node_uid, - user_verify_key_low=self.user_verify_key_low, - user_verify_key_high=self.user_verify_key_high, - obj_uid_to_diff=self.obj_uid_to_diff, - obj_dependencies=self.obj_dependencies, - batches=batches, - all_batches=self.all_batches, - low_state=self.low_state, - high_state=self.high_state, - direction=self.direction, - filters=filters, - ) - def reset_filters( + if inplace: + self.filters = filters + self.batches = batches + return self + else: + return NodeDiff( + low_node_uid=self.low_node_uid, + high_node_uid=self.high_node_uid, + user_verify_key_low=self.user_verify_key_low, + user_verify_key_high=self.user_verify_key_high, + obj_uid_to_diff=self.obj_uid_to_diff, + obj_dependencies=self.obj_dependencies, + batches=batches, + all_batches=self.all_batches, + low_state=self.low_state, + high_state=self.high_state, + direction=self.direction, + filters=filters, + ) + + def _filter( self, + user_email: str | None = None, + obj_type: str | type | None = None, include_ignored: bool = False, include_same: bool = False, + inplace: bool = True, ) -> Self: - filters = [] - if not include_ignored: - filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne)) - if not include_same: - filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne)) - return self._apply_filters(filters) - - def filter( - self, - user: str | None = None, - obj_type: type | None = None, - ) -> Self: - current_filters = self.filters new_filters = [] - if user is not None: - new_filters.append(NodeDiffFilter(FilterProperty.USER, user)) + if user_email is not None: + new_filters.append( + NodeDiffFilter(FilterProperty.USER, user_email, operator.eq) + ) if obj_type is not None: - new_filters.append(NodeDiffFilter(FilterProperty.TYPE, obj_type)) - - if len(new_filters) == 0: - return self - - new_filter_properties = {f.filter_property for f in new_filters} - # Only add filters that are not in the new filters - # - remove duplicate filters - # - overwrite filters with the same property but different value - # (example: cannot filter on 2 different users) - for current_filter in current_filters: - if current_filter.filter_property not in new_filter_properties: - new_filters.append(current_filter) + if isinstance(obj_type, type): + obj_type = obj_type.__name__ + new_filters.append( + NodeDiffFilter(FilterProperty.TYPE, obj_type, operator.eq) + ) + if not include_ignored: + new_filters.append( + NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne) + ) + if not include_same: + new_filters.append( + NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne) + ) - return self._apply_filters(new_filters) + return self._apply_filters(new_filters, inplace=inplace) class SyncInstruction(SyftObject):