diff --git a/crypten/cryptensor.py b/crypten/cryptensor.py
index e9aef2f5..7354c921 100644
--- a/crypten/cryptensor.py
+++ b/crypten/cryptensor.py
@@ -16,6 +16,45 @@
 # list of all static functions that CrypTensors support:
 STATIC_FUNCTIONS = ["cat", "stack"]
 STATIC_FUNCTION_MAPPING = {getattr(torch, name): name for name in STATIC_FUNCTIONS}
+HANDLED_FUNCTION_MAPPING = {}
+
+
+# decorator enabling overriding of torch functions:
+def implements(torch_function_name):
+    """Register a torch function override for CrypTensor."""
+
+    def wrapper(func):
+        HANDLED_FUNCTION_MAPPING[getattr(torch, torch_function_name)] = func
+        HANDLED_FUNCTION_MAPPING[getattr(torch.Tensor, torch_function_name)] = func
+        # tag the function so _InheritDecorators can re-apply this decorator
+        # when a subclass overrides the decorated method:
+        func.inherit_implements = wrapper
+        return func
+
+    return wrapper
+
+
+class _Base:
+    pass
+
+
+class _InheritDecorators:
+    """
+    Helper class ensuring that subclasses of CrypTensor inherit @implements decorators.
+    """
+
+    def __init_subclass__(cls, *args, **kwargs):
+        super().__init_subclass__(*args, **kwargs)
+        decorator_registry = getattr(cls, "_decorator_registry", {}).copy()
+        cls._decorator_registry = decorator_registry
+
+        # record methods newly decorated with @implements in this (sub)class:
+        for name, obj in cls.__dict__.items():
+            if getattr(obj, "inherit_implements", False) and name not in decorator_registry:
+                decorator_registry[name] = obj.inherit_implements
+
+        # re-apply registered decorators to methods this subclass overrides:
+        for name, decorator in decorator_registry.items():
+            if name in cls.__dict__ and getattr(getattr(cls, name), "inherit_implements", None) != decorator:
+                setattr(cls, name, decorator(cls.__dict__[name]))
 
 
 def _find_all_cryptensors(inputs):
@@ -48,7 +87,7 @@ def __getattribute__(cls, name):
         return type.__getattribute__(cls, name)
 
 
-class CrypTensor(object, metaclass=CrypTensorMetaclass):
+class CrypTensor(_Base, _InheritDecorators, metaclass=CrypTensorMetaclass):
     """
     Abstract implementation of encrypted tensor type. Every subclass of `CrypTensor`
     must implement the methods defined here. The actual tensor data should live in
@@ -287,6 +326,7 @@ def detach(self):
         clone.requires_grad = False
         return clone
 
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
         """Allows torch static functions to work on CrypTensors."""
         if kwargs is None:
@@ -296,6 +336,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
             # dispatch torch.{cat,stack} call on CrypTensor to CrypTen:
             return getattr(crypten, STATIC_FUNCTION_MAPPING[func])(*args, **kwargs)
+        elif func in HANDLED_FUNCTION_MAPPING:
+            return HANDLED_FUNCTION_MAPPING[func](*args, **kwargs)
         else:
             raise NotImplementedError(
                 f"CrypTen does not support torch function {func}."
             )
@@ -390,6 +432,7 @@ def __getattribute__(self, name):
             - If this fails and function is REQUIRED, raise error
         b. Fetch from grad_fn.forward, ignoring AutogradContext
         """
+        # 1. If name is in PROTECTED_ATTRIBUTES, fetch from the CrypTensor object.
         if name in CrypTensor.PROTECTED_ATTRIBUTES:
             return object.__getattribute__(self, name)
@@ -470,6 +513,7 @@ def __iadd__(self, tensor):
         """Adds tensor to this tensor (in-place)."""
         return self.add_(tensor)
 
+    @implements("sub")
     def sub(self, tensor):
         """Subtracts a :attr:`tensor` from :attr:`self` tensor.
         The shape of :attr:`tensor` must be
@@ -592,6 +636,7 @@ def clone(self):
         """
         raise NotImplementedError("clone is not implemented")
 
+    @implements("add")
     def add(self, tensor):
         r"""Adds :attr:`tensor` to this :attr:`self`.
@@ -611,6 +656,7 @@ def add(self, tensor):
         """
         raise NotImplementedError("add is not implemented")
 
+    @implements("mul")
     def mul(self, tensor):
         r"""Element-wise multiply with a :attr:`tensor`.
diff --git a/crypten/mpc/mpc.py b/crypten/mpc/mpc.py
index e7300c02..bafc49d4 100644
--- a/crypten/mpc/mpc.py
+++ b/crypten/mpc/mpc.py
@@ -10,6 +10,7 @@
 from crypten.common.tensor_types import is_tensor
 from crypten.common.util import torch_stack
 from crypten.config import cfg
+from crypten.cryptensor import implements
 from crypten.cuda import CUDALongTensor
 
 from ..cryptensor import CrypTensor
@@ -320,6 +321,11 @@ def div(self, y):
     "conv_transpose1d",
     "conv_transpose2d",
 ]
+TORCH_OVERRIDE_BINARY_FUNCTIONS = [
+    "add",
+    "sub",
+    "mul",
+]
 
 
 def _add_unary_passthrough_function(name):
@@ -333,12 +339,22 @@ def unary_wrapper_function(self, *args, **kwargs):
 
 def _add_binary_passthrough_function(name):
     def binary_wrapper_function(self, value, *args, **kwargs):
+        func_name = name
+        if torch.is_tensor(self) and isinstance(value, MPCTensor):
+            self, value = value, self  # swap order of arguments
+            func_name = f"__r{name}__"  # invoke __radd__, __rsub__, etc.
+
         result = self.shallow_copy()
         if isinstance(value, MPCTensor):
             value = value._tensor
-        result._tensor = getattr(result._tensor, name)(value, *args, **kwargs)
+        result._tensor = getattr(result._tensor, func_name)(value, *args, **kwargs)
         return result
 
+    # register as an override of the corresponding torch function:
+    if name in TORCH_OVERRIDE_BINARY_FUNCTIONS:
+        implements(name)(binary_wrapper_function)
+
+    # register function on MPCTensor:
     setattr(MPCTensor, name, binary_wrapper_function)
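
Usage sketch (not part of the diff): a minimal example of what these overrides enable, assuming a CrypTen build that includes this patch. `crypten.init()`, `crypten.cryptensor()`, and `get_plain_text()` are existing CrypTen APIs; the expected outputs are approximate due to fixed-point encoding.

    import crypten
    import torch

    crypten.init()

    x = torch.tensor([1.0, 2.0, 3.0])
    x_enc = crypten.cryptensor(x)  # an MPCTensor under the default provider

    # Previously only torch.cat/torch.stack dispatched to CrypTen; with the
    # @implements registrations above, add/sub/mul now route through
    # CrypTensor.__torch_function__ to MPCTensor's passthrough wrappers:
    y = torch.add(x_enc, x)  # encrypted operand first
    z = torch.sub(x, x_enc)  # plain tensor first; swapped and handled via __rsub__

    print(y.get_plain_text())  # ≈ tensor([2., 4., 6.])
    print(z.get_plain_text())  # ≈ tensor([0., 0., 0.])

Note that the plain-tensor-first case (`torch.sub(x, x_enc)`) relies on the `__r{name}__` fallback added to `_add_binary_passthrough_function`, which assumes the underlying share type implements the reversed dunder methods.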