Add load/save closure functions (facebookresearch#201)
Summary:
Pull Request resolved: fairinternal/CrypTen#201

This diff adds support for using custom load and save functions in `crypten.load` and `crypten.save`. The custom load/save functions need to have the same interface as `torch.load/torch.save`.

Note that this does not yet change the names of the functions to `load_from_party` and `save_from_party`, as we discussed in the design document. I plan to do that in the next diff.
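For illustration only (not part of the commit): a minimal sketch of how the new `save_closure`/`load_closure` parameters could be used, mirroring the numpy-backed closures exercised in the updated test. The names `custom_save_function`, `custom_load_function`, and the `tensor.npy` path are assumptions, and the snippet presumes `crypten.init()` has already been called in each party's process.

# Illustrative sketch -- mirrors the numpy closures from the updated test.
# Assumes crypten.init() has already been called in each party's process.
import crypten
import numpy as np
import torch


def custom_save_function(obj, f):
    # Same call signature as torch.save(obj, f)
    np.save(f, obj.numpy())


def custom_load_function(f):
    # Same call signature as torch.load(f)
    return torch.from_numpy(np.load(f))


tensor = torch.randn(3, 3)

# The src party writes the plaintext tensor; crypten.save ends with a barrier,
# so the file exists before any other party proceeds.
crypten.save(tensor, "tensor.npy", src=0, save_closure=custom_save_function)

# The src party reads the file with the matching load closure and the result
# reaches every party as a CrypTensor.
encrypted = crypten.load("tensor.npy", src=0, load_closure=custom_load_function)

Any callable with the `torch.save`/`torch.load` interface works as a closure; extra kwargs passed to `crypten.save`/`crypten.load` are forwarded to it.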

Reviewed By: knottb

Differential Revision: D21025272

fbshipit-source-id: 8ae99e7b4ef0aa0ecfe99e84f4b23417f355bdaf
Shobha Venkataraman authored and facebook-github-bot committed Apr 20, 2020
1 parent 2389544 commit 2825f47
Showing 2 changed files with 67 additions and 15 deletions.
25 changes: 21 additions & 4 deletions crypten/__init__.py
@@ -201,7 +201,15 @@ def _setup_przs():
    comm.get().global_generator.manual_seed(global_seed.item())


def load(f, preloaded=None, encrypted=False, dummy_model=None, src=0, **kwargs):
def load(
    f,
    preloaded=None,
    encrypted=False,
    dummy_model=None,
    src=0,
    load_closure=torch.load,
    **kwargs
):
    """
    Loads an object saved with `torch.save()` or `crypten.save()`.
@@ -221,6 +229,9 @@ def load(f, preloaded=None, encrypted=False, dummy_model=None, src=0, **kwargs):
            party will attempt to read in the specified file. If `src` is
            specified, the source party will read the tensor from `f` and it
            will broadcast it to the other parties
        load_closure: Custom load function that matches the interface of `torch.load`,
            to be used when the tensor is saved with a custom save function in
            `crypten.save`. Additional kwargs are passed on to the closure.
    """
    if dummy_model is not None:
        warnings.warn(
@@ -236,7 +247,7 @@ def load(f, preloaded=None, encrypted=False, dummy_model=None, src=0, **kwargs):

    # source party
    if comm.get().get_rank() == src:
        result = preloaded if preloaded else torch.load(f, **kwargs)
        result = preloaded if preloaded else load_closure(f, **kwargs)

        # Zero out the tensors / modules to hide loaded data from broadcast
        if torch.is_tensor(result):
@@ -245,13 +256,16 @@ def load(f, preloaded=None, encrypted=False, dummy_model=None, src=0, **kwargs):
            result_zeros = copy.deepcopy(result)
            result_zeros.set_all_parameters(0)
        else:
            result = comm.get().broadcast_obj(-1, src)
            raise TypeError("Unrecognized load type %s" % type(result))

        comm.get().broadcast_obj(result_zeros, src)

    # Non-source party
    else:
        result = comm.get().broadcast_obj(None, src)
        if isinstance(result, int) and result == -1:
            raise TypeError("Unrecognized load type from src party")

        if torch.is_tensor(result):
            result = crypten.cryptensor(result, src=src)
@@ -262,7 +276,7 @@ def load(f, preloaded=None, encrypted=False, dummy_model=None, src=0, **kwargs):
    return result


def save(obj, f, src=0, **kwargs):
def save(obj, f, src=0, save_closure=torch.save, **kwargs):
"""
Saves a CrypTensor or PyTorch tensor to a file.
@@ -271,6 +285,9 @@ def save(obj, f, src=0, **kwargs):
        f: a file-like object (has to implement `read()`, `readline()`,
            `tell()`, and `seek()`), or a string containing a file name
        src: The source party that writes data to the specified file.
        save_closure: Custom save function that matches the interface of `torch.save`,
            to be used when the tensor will be loaded with a custom load function in
            `crypten.load`. Additional kwargs are passed on to the closure.
    """
    if is_encrypted_tensor(obj):
        raise NotImplementedError("Saving encrypted tensors is not yet supported")
@@ -281,7 +298,7 @@ def save(obj, f, src=0, **kwargs):
    ), "Save failed: src must be an integer in [0, world_size)"

    if comm.get().get_rank() == src:
        torch.save(obj, f, **kwargs)
        save_closure(obj, f, **kwargs)

    # Implement barrier to avoid race conditions that require file to exist
    comm.get().barrier()
57 changes: 46 additions & 11 deletions test/test_crypten.py
@@ -135,26 +135,61 @@ def test_cryptensor_instantiation(self):
    def test_save_load(self):
        """Test that crypten.save and crypten.load properly save and load tensors"""
        import tempfile
        import numpy as np

        def custom_load_function(f):
            np_arr = np.load(f)
            tensor = torch.from_numpy(np_arr)
            return tensor

        def custom_save_function(obj, f):
            np_arr = obj.numpy()
            np.save(f, np_arr)

        comm = crypten.communicator
        filename = tempfile.NamedTemporaryFile(delete=True).name
        all_save_fns = [torch.save, custom_save_function]
        all_load_fns = [torch.load, custom_load_function]
        all_file_completions = [".pth", ".npy"]
        all_test_load_fns = [torch.load, np.load]
        for dimensions in range(1, 5):
            # Create tensors with different sizes on each rank
            size = [self.rank + 1] * dimensions
            size = tuple(size)
            tensor = torch.randn(size=size)

            for src in range(comm.get().get_world_size()):
                crypten.save(tensor, filename, src=src)
                encrypted_load = crypten.load(filename, src=src)

                reference_size = tuple([src + 1] * dimensions)
                self.assertEqual(encrypted_load.size(), reference_size)

                size_out = [src + 1] * dimensions
                reference = tensor if self.rank == src else torch.empty(size=size_out)
                comm.get().broadcast(reference, src=src)
                self._check(encrypted_load, reference, "crypten.load() failed")
            for i, save_closure in enumerate(all_save_fns):
                load_closure = all_load_fns[i]
                test_load_fn = all_test_load_fns[i]
                complete_file = filename + all_file_completions[i]
                for src in range(comm.get().get_world_size()):
                    crypten.save(
                        tensor, complete_file, src=src, save_closure=save_closure
                    )
                    # the following line will throw an error if an object saved with
                    # torch.save is attempted to be loaded with np.load
                    if self.rank == src:
                        test_load_fn(complete_file)

                    encrypted_load = crypten.load(
                        complete_file, src=src, load_closure=load_closure
                    )

                    reference_size = tuple([src + 1] * dimensions)
                    self.assertEqual(encrypted_load.size(), reference_size)

                    size_out = [src + 1] * dimensions
                    reference = (
                        tensor if self.rank == src else torch.empty(size=size_out)
                    )
                    comm.get().broadcast(reference, src=src)
                    self._check(encrypted_load, reference, "crypten.load() failed")

        # test for invalid load_closure
        with self.assertRaises(TypeError):
            crypten.load(
                complete_file, src=src, load_closure=(lambda f: None)
            )

    def test_save_load_module(self):
        """Test that crypten.save and crypten.load properly save and load modules"""
