Skip to content

Commit

Permalink
Allow export of non fully replicated tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
EiffL committed May 11, 2021
1 parent d70c4d6 commit 70a13d3
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions mesh_tensorflow/hvd_simd_mesh_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,11 +621,14 @@ def export_to_tf_tensor(self, x, laid_out_x):
"""
tensor_layout = self.tensor_layout(x.shape)
if not tensor_layout.is_fully_replicated:
raise NotImplementedError(
"SimdMeshImpl only supports export_to_tf_tensor of fully-replicated "
"Tensors. Try reshaping to new dimension names. "
" x.shape = %s tensor_layout=%s"
% (x.shape, tensor_layout))
print("Warning: Exported tensor is not fully replicated"
" x.shape = %s tensor_layout=%s"
% (x.shape, tensor_layout))
# raise NotImplementedError(
# "SimdMeshImpl only supports export_to_tf_tensor of fully-replicated "
# "Tensors. Try reshaping to new dimension names. "
# " x.shape = %s tensor_layout=%s"
# % (x.shape, tensor_layout))
return laid_out_x.one_slice

def import_tf_tensor(self, x, tf_x):
Expand Down

0 comments on commit 70a13d3

Please sign in to comment.