Fixes issue 403 for select functions: add, sub, mul #413

Open: wants to merge 1 commit into base branch `main`.
`crypten/cryptensor.py`: 48 additions & 1 deletion
```diff
@@ -5,6 +5,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
 from contextlib import contextmanager
 
 import torch
@@ -16,6 +17,45 @@
 # list of all static functions that CrypTensors support:
 STATIC_FUNCTIONS = ["cat", "stack"]
 STATIC_FUNCTION_MAPPING = {getattr(torch, name): name for name in STATIC_FUNCTIONS}
+HANDLED_FUNCTION_MAPPING = {}
+
+# decorator enabling overriding of torch functions:
+def implements(torch_function_name):
+    """Register a torch function override for CrypTensor."""
+
+    @functools.wraps(torch_function_name)
+    def wrapper(func):
+        HANDLED_FUNCTION_MAPPING[getattr(torch, torch_function_name)] = func
+        HANDLED_FUNCTION_MAPPING[getattr(torch.Tensor, torch_function_name)] = func
+        return func
+
+    wrapper.inherit_implements = implements
+    return wrapper
+
+
+class _Base:
+    pass
+
+
+class _InheritDecorators:
+    """
+    Helper class ensuring that subclasses of CrypTensor inherit @implements decorators.
+    """
+
+    def __init_subclass__(cls, *args, **kwargs):
+        super().__init_subclass__(*args, **kwargs)
+        decorator_registry = getattr(cls, "_decorator_registry", {}).copy()
+        cls._decorator_registry = decorator_registry
+
+        # annotate newly decorated methods in the current subclass:
+        for name, obj in cls.__dict__.items():
+            if getattr(obj, "inherit_implements", False) and name not in decorator_registry:
+                decorator_registry[name] = obj.inherit_implements
+
+        # decorate all methods annotated in the registry:
+        for name, decorator in decorator_registry.items():
+            if name in cls.__dict__ and getattr(getattr(cls, name), "inherit_implements", None) != decorator:
+                setattr(cls, name, decorator(cls.__dict__[name]))
 
 
 def _find_all_cryptensors(inputs):
```
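The `implements` decorator above registers an override under both the `torch.add`-style function and the `torch.Tensor.add`-style method, so dispatch works for either call form. A minimal standalone sketch of that registration pattern (`implements_sketch`, `HANDLED`, and `my_add` are toy names, not CrypTen's):

```python
import torch

HANDLED = {}  # toy stand-in for HANDLED_FUNCTION_MAPPING


def implements_sketch(torch_function_name):
    """Register an override under torch.<name> and torch.Tensor.<name>."""

    def wrapper(func):
        HANDLED[getattr(torch, torch_function_name)] = func
        HANDLED[getattr(torch.Tensor, torch_function_name)] = func
        return func

    return wrapper


@implements_sketch("add")
def my_add(x, y):
    return ("intercepted add", x, y)


# both entry points now resolve to the override:
assert HANDLED[torch.add] is my_add
assert HANDLED[torch.Tensor.add] is my_add
```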
```diff
@@ -48,7 +88,7 @@ def __getattribute__(cls, name):
         return type.__getattribute__(cls, name)
 
 
-class CrypTensor(object, metaclass=CrypTensorMetaclass):
+class CrypTensor(_Base, _InheritDecorators, metaclass=CrypTensorMetaclass):
     """
     Abstract implementation of encrypted tensor type. Every subclass of `CrypTensor`
     must implement the methods defined here. The actual tensor data should live in
```
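Because `CrypTensor` now derives from `_InheritDecorators`, Python invokes `__init_subclass__` whenever a subclass such as `MPCTensor` is created, giving each subclass its own copy of the decorator registry. A toy illustration of that copy-on-subclass behavior (simplified; not CrypTen code):

```python
class Base:
    _decorator_registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # give every new subclass its own copy so that edits made for
        # one class never leak into its parents or siblings:
        cls._decorator_registry = cls._decorator_registry.copy()


class Child(Base):
    pass


Child._decorator_registry["sub"] = "decorator-for-sub"
assert "sub" not in Base._decorator_registry  # parent registry untouched
```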
```diff
@@ -287,6 +327,7 @@ def detach(self):
         clone.requires_grad = False
         return clone
 
+    @classmethod
     def __torch_function__(self, func, types, args=(), kwargs=None):
         """Allows torch static functions to work on CrypTensors."""
         if kwargs is None:
@@ -296,6 +337,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
 
             # dispatch torch.{cat,stack} call on CrypTensor to CrypTen:
             return getattr(crypten, STATIC_FUNCTION_MAPPING[func])(*args, **kwargs)
+        elif func in HANDLED_FUNCTION_MAPPING:
+            return HANDLED_FUNCTION_MAPPING[func](*args, **kwargs)
         else:
             raise NotImplementedError(
                 f"CrypTen does not support torch function {func}."
```
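PyTorch's override protocol calls `__torch_function__` on any argument whose type defines it, passing the intercepted callable as `func`; the new `elif` branch simply looks `func` up in the registry. A self-contained sketch of that dispatch path (the `Toy` class and `HANDLED` dict are hypothetical stand-ins, not CrypTensor):

```python
import torch

# hypothetical registry with a single override for torch.add:
HANDLED = {torch.add: lambda x, y: ("handled add", x, y)}


class Toy:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in HANDLED:
            return HANDLED[func](*args, **kwargs)
        raise NotImplementedError(f"no override for {func}")


# torch sees a Toy argument, finds __torch_function__, and dispatches:
print(torch.add(Toy(), 1))  # -> ('handled add', <Toy object ...>, 1)
```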
```diff
@@ -390,6 +433,7 @@ def __getattribute__(self, name):
             - If this fails and function is REQUIRED, raise error
         b. Fetch from grad_fn.forward, ignoring AutogradContext
         """
+
         # 1. If name is in PROTECTED_ATTRIBUTES, fetch from the CrypTensor object.
         if name in CrypTensor.PROTECTED_ATTRIBUTES:
             return object.__getattribute__(self, name)
@@ -470,6 +514,7 @@ def __iadd__(self, tensor):
         """Adds tensor to this tensor (in-place)."""
         return self.add_(tensor)
 
+    @implements("sub")
     def sub(self, tensor):
         """Subtracts a :attr:`tensor` from :attr:`self` tensor.
         The shape of :attr:`tensor` must be
@@ -592,6 +637,7 @@ def clone(self):
         """
         raise NotImplementedError("clone is not implemented")
 
+    @implements("add")
     def add(self, tensor):
         r"""Adds :attr:`tensor` to this :attr:`self`.
 
@@ -611,6 +657,7 @@ def add(self, tensor):
         """
         raise NotImplementedError("add is not implemented")
 
+    @implements("mul")
     def mul(self, tensor):
         r"""Element-wise multiply with a :attr:`tensor`.
 
```
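With `add`, `sub`, and `mul` registered via `@implements`, the intended net effect is that the corresponding `torch` functions no longer raise `NotImplementedError` on encrypted tensors. A usage sketch, assuming a working CrypTen install with this patch applied (the exact output is illustrative):

```python
import crypten
import torch

crypten.init()
x_enc = crypten.cryptensor(torch.tensor([1.0, 2.0]))

# previously raised NotImplementedError; with the registry in place this
# should dispatch through HANDLED_FUNCTION_MAPPING to the MPC backend:
y_enc = torch.add(x_enc, 1.0)
print(y_enc.get_plain_text())  # expected: tensor([2., 3.])
```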
`crypten/mpc/mpc.py`: 17 additions & 1 deletion
```diff
@@ -10,6 +10,7 @@
 from crypten.common.tensor_types import is_tensor
 from crypten.common.util import torch_stack
 from crypten.config import cfg
+from crypten.cryptensor import implements
 from crypten.cuda import CUDALongTensor
 
 from ..cryptensor import CrypTensor
```
```diff
@@ -320,6 +321,11 @@ def div(self, y):
     "conv_transpose1d",
     "conv_transpose2d",
 ]
+TORCH_OVERRIDE_BINARY_FUNCTION = [
+    "add",
+    "sub",
+    "mul",
+]
 
 
 def _add_unary_passthrough_function(name):
```
```diff
@@ -333,12 +339,22 @@ def unary_wrapper_function(self, *args, **kwargs):
 
 def _add_binary_passthrough_function(name):
     def binary_wrapper_function(self, value, *args, **kwargs):
+        func_name = name
+        if torch.is_tensor(self) and isinstance(value, MPCTensor):
+            self, value = value, self  # swap order of arguments
+            func_name = f"__r{name}__"  # invoke __radd__, __rsub__, etc.
+
         result = self.shallow_copy()
         if isinstance(value, MPCTensor):
             value = value._tensor
-        result._tensor = getattr(result._tensor, name)(value, *args, **kwargs)
+        result._tensor = getattr(result._tensor, func_name)(value, *args, **kwargs)
         return result
 
+    # register as torch function that can be overridden:
+    if name in TORCH_OVERRIDE_BINARY_FUNCTION:
+        implements(name)(binary_wrapper_function)
+
+    # register function into MPCTensor:
     setattr(MPCTensor, name, binary_wrapper_function)
```
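The argument swap works because, after swapping, `getattr(result._tensor, "__rsub__")(value)` computes `value - result._tensor`, preserving the sign of non-commutative operations. A quick demonstration of the reflected methods on plain tensors (illustrative only):

```python
import torch

x = torch.tensor([10.0])

# the reflected method computes "value - self", matching the swapped call:
assert x.__rsub__(3.0).item() == -7.0   # 3 - 10
# for commutative ops, reflected and direct forms agree:
assert x.__radd__(3.0).item() == x.add(3.0).item() == 13.0
```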