Skip to content

Commit

Permalink
Change current crypten.load and crypten.save functions to load_from_p…
Browse files Browse the repository at this point in the history
…arty and save_from_party respectively (facebookresearch#202)

Summary:
Pull Request resolved: fairinternal/CrypTen#202

Change current crypten.load and crypten.save functions to load_from_party and save_from_party respectively.

Reviewed By: knottb

Differential Revision: D21026301

fbshipit-source-id: 7ed8a8b483432caa826198867d22a54542393178
  • Loading branch information
Shobha Venkataraman authored and facebook-github-bot committed Apr 20, 2020
1 parent 2825f47 commit d6f221b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 11 deletions.
46 changes: 41 additions & 5 deletions crypten/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _setup_przs():
comm.get().global_generator.manual_seed(global_seed.item())


def load(
def load_from_party(
f,
preloaded=None,
encrypted=False,
Expand All @@ -211,7 +211,7 @@ def load(
**kwargs
):
"""
Loads an object saved with `torch.save()` or `crypten.save()`.
Loads an object saved with `torch.save()` or `crypten.save_from_party()`.
Args:
f: a file-like object (has to implement `read()`, `readline()`,
Expand All @@ -231,7 +231,7 @@ def load(
will broadcast it to the other parties
load_closure: Custom load function that matches the interface of `torch.load`,
to be used when the tensor is saved with a custom save function in
`crypten.save`. Additional kwargs are passed on to the closure.
`crypten.save_from_party`. Additional kwargs are passed on to the closure.
"""
if dummy_model is not None:
warnings.warn(
Expand Down Expand Up @@ -276,7 +276,29 @@ def load(
return result


def save(obj, f, src=0, save_closure=torch.save, **kwargs):
def load(
f,
preloaded=None,
encrypted=False,
dummy_model=None,
src=0,
load_closure=torch.load,
**kwargs
):
"""
Loads an object saved with `torch.save()` or `crypten.save_from_party()`.
Note: this function is deprecated; please use load_from_party instead.
"""
warnings.warn(
"The current 'load' function is deprecated, and will be removed soon. "
"To continue using current 'load' functionality, please use the "
"'load_from_party' function instead.",
DeprecationWarning,
)
load_from_party(f, preloaded, encrypted, dummy_model, src, load_closure, **kwargs)


def save_from_party(obj, f, src=0, save_closure=torch.save, **kwargs):
"""
Saves a CrypTensor or PyTorch tensor to a file.
Expand All @@ -287,7 +309,7 @@ def save(obj, f, src=0, save_closure=torch.save, **kwargs):
src: The source party that writes data to the specified file.
save_closure: Custom save function that matches the interface of `torch.save`,
to be used when the tensor is saved with a custom load function in
`crypten.load`. Additional kwargs are passed on to the closure.
`crypten.load_from_party`. Additional kwargs are passed on to the closure.
"""
if is_encrypted_tensor(obj):
raise NotImplementedError("Saving encrypted tensors is not yet supported")
Expand All @@ -304,6 +326,20 @@ def save(obj, f, src=0, save_closure=torch.save, **kwargs):
comm.get().barrier()


def save(obj, f, src=0, save_closure=torch.save, **kwargs):
"""
Saves a CrypTensor or PyTorch tensor to a file.
Note: this function is deprecated, please use save_from_party instead
"""
warnings.warn(
"The current 'save' function is deprecated, and will be removed soon. "
"To continue using current 'save' functionality, please use the "
"'save_from_party' function instead.",
DeprecationWarning,
)
save_from_party(obj, f, src, save_closure, **kwargs)


def where(condition, input, other):
"""
Return a tensor of elements selected from either `input` or `other`, depending
Expand Down
2 changes: 1 addition & 1 deletion examples/tfe_benchmarks/tfe_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def validate(val_loader, model, criterion, print_freq=10, flatten=False):
def save_checkpoint(
state, is_best, filename="checkpoint.pth.tar", model_best="model_best.pth.tar"
):
# TODO: use crypten.save() in future.
# TODO: use crypten.save_from_party() in future.
rank = comm.get().get_rank()
# only save for process rank = 0
if rank == 0:
Expand Down
11 changes: 6 additions & 5 deletions test/test_crypten.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,16 @@ def custom_save_function(obj, f):
test_load_fn = all_test_load_fns[i]
complete_file = filename + all_file_completions[i]
for src in range(comm.get().get_world_size()):
crypten.save(
crypten.save_from_party(
tensor, complete_file, src=src, save_closure=save_closure
)

# the following line will throw an error if an object saved with
# torch.save is attempted to be loaded with np.load
if self.rank == src:
test_load_fn(complete_file)

encrypted_load = crypten.load(
encrypted_load = crypten.load_from_party(
complete_file, src=src, load_closure=load_closure
)

Expand All @@ -187,7 +188,7 @@ def custom_save_function(obj, f):

# test for invalid load_closure
with self.assertRaises(TypeError):
crypten.load(
crypten.load_from_party(
complete_file, src=src, load_closure=(lambda f: None)
)

Expand All @@ -205,9 +206,9 @@ def test_save_load_module(self):

filename = tempfile.NamedTemporaryFile(delete=True).name
for src in range(comm.get().get_world_size()):
crypten.save(test_model, filename, src=src)
crypten.save_from_party(test_model, filename, src=src)

result = crypten.load(filename, src=src)
result = crypten.load_from_party(filename, src=src)
if src == rank:
for param in result.parameters(recurse=True):
self.assertTrue(
Expand Down

0 comments on commit d6f221b

Please sign in to comment.