diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 1473fd5f995b..18a64629f258 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -1068,7 +1068,7 @@ def test_backward_optimization_barrier(self):
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
+        '%opt-barrier.38 = (f32[1,64]{1,0}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{1,0}, f32[1]{0}, f32[2,64]{1,0}) %tuple.37)',
         hlo)

   def test_mark_shard_scalar(self):
diff --git a/test/test_operations.py b/test/test_operations.py
index 1af928e6a471..4a1338dc30e8 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -58,7 +58,7 @@
 DeviceSupport = collections.namedtuple('DeviceSupport', ['num_devices'])

 XLA_DISABLE_FUNCTIONALIZATION = bool(
-    os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', False))
+    int(os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', '0')))


 def _is_on_tpu():
@@ -2783,6 +2783,66 @@ def test_unsafe_buffer_pointer(self):
     buf_ptr_3 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_3)
     self.assertGreaterEqual(buf_ptr_3, 0)

+  def test_consistent_strides(self):
+    # Tests whether the `is_contiguous()` method is consistent with the tensor's strides.
+    # In other words, if `is_contiguous()` is true, the tensor's strides should reflect
+    # a contiguous storage.
+
+    def stride_is_contiguous(tensor):
+      # Order the list of (size, stride) tuples in ascending stride order, so that the
+      # first element corresponds to the smallest stride.
+      sizes_and_strides = list(
+          sorted(zip(tensor.shape, tensor.stride()), key=lambda t: t[1]))
+
+      # A contiguous tensor's smallest stride should be 1.
+      if sizes_and_strides[0][1] != 1:
+        return False
+
+      # Check whether the next larger stride `stride[i + 1]` is equal to the current
+      # one `stride[i]` multiplied by the current size `size[i]`.
+      for i, (size, stride) in enumerate(sizes_and_strides[:-1]):
+        if sizes_and_strides[i + 1][1] != stride * size:
+          return False
+
+      return True
+
+    def assert_strides_consistent(tensor):
+      self.assertEqual(tensor.is_contiguous(), stride_is_contiguous(tensor))
+
+    # Obviously contiguous, since it was created with torch.rand.
+    a = torch.rand(10).to(xm.xla_device())
+    assert_strides_consistent(a)
+
+    # Not contiguous, since we are skipping every other element.
+    b = a[::2]
+    assert_strides_consistent(b)
+
+    # Still not contiguous, since 'b' is not contiguous.
+    c = b[1:]
+    assert_strides_consistent(c)
+
+  def test_contiguity_on_different_memory_format(self):
+    # Create a contiguous strided tensor.
+    a = torch.rand(2, 3, 4, 5).to(xm.xla_device())
+    self.assertTrue(a.is_contiguous())
+    # When functionalization is disabled, we fall back to the old behavior, where
+    # `is_contiguous()` calls always return True.
+    self.assertEqual(
+        a.is_contiguous(memory_format=torch.channels_last),
+        XLA_DISABLE_FUNCTIONALIZATION)
+
+    # Make `a` contiguous in torch.channels_last memory format.
+    #
+    # This should, in theory, be a no-op, since we can't really change the strides
+    # of XLA tensors. However, `contiguous` is a composite operation that checks the
+    # tensor's metadata. Therefore, it will clone the tensor whenever its strides
+    # do not conform to the given memory format.
+    b = a.contiguous(memory_format=torch.channels_last)
+    # When functionalization is disabled, we fall back to the old behavior, where
+    # `is_contiguous()` calls always return True.
+    self.assertEqual(b.is_contiguous(), XLA_DISABLE_FUNCTIONALIZATION)
+    self.assertTrue(b.is_contiguous(memory_format=torch.channels_last))
+

 class TestDLPack(parameterized.TestCase):
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index d355d6c378f8..2ea7e4e6a871 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1227,11 +1227,28 @@ at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
 }

 at::Tensor XLANativeFunctions::clone(
-    const at::Tensor& self,
-    std::optional<at::MemoryFormat> /* memory_format */) {
+    const at::Tensor& self, std::optional<at::MemoryFormat> memory_format) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
+
+  at::Tensor out = bridge::AtenFromXlaTensor(
       tensor_methods::clone(bridge::GetXlaTensor(self)));
+
+  if (!runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
+    at::Tensor ref;
+    if (memory_format.has_value() &&
+        *memory_format != at::MemoryFormat::Preserve) {
+      // We need to run the meta function as reference, for setting the correct
+      // strides of the output tensor.
+      at::Tensor ref_self = self.to(at::kMeta);
+      ref = ref_self.clone(memory_format);
+    } else {
+      ref = self;
+    }
+    out.unsafeGetTensorImpl()->set_sizes_and_strides(ref.sym_sizes(),
+                                                     ref.sym_strides());
+  }
+
+  return out;
 }

 at::Tensor XLANativeFunctions::constant_pad_nd(const at::Tensor& self,
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index 4e69127ff816..22539532663d 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -173,9 +173,16 @@ int64_t XLATensorImpl::numel_custom() const {
 }

 bool XLATensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
+  // If functionalization is disabled, the tensors' metadata isn't being
+  // updated w.r.t. the output of meta functions. Therefore, we fall back to
+  // the old behavior of always returning true.
+  if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
+    return true;
+  }
+
   // Storage is always contiguous, but the tensor metadata is_contiguous_ might
   // be false due to the update in the functionalization layer.
-  return true;
+  return c10::TensorImpl::is_contiguous_custom(memory_format);
 }

 void XLATensorImpl::SetupSizeProperties() {
diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py
index 9008e03dbd91..d6eb4d1a65b9 100644
--- a/torch_xla/experimental/scan.py
+++ b/torch_xla/experimental/scan.py
@@ -4,7 +4,7 @@
 """

-from typing import Callable, TypeVar
+from typing import Callable, Tuple, TypeVar

 import torch
 from torch.utils._pytree import tree_map, tree_iter
@@ -15,10 +15,10 @@


 def scan(
-    fn: Callable[[Carry, X], tuple[Carry, Y]],
+    fn: Callable[[Carry, X], Tuple[Carry, Y]],
     init: Carry,
     xs: X,
-) -> tuple[Carry, Y]:
+) -> Tuple[Carry, Y]:
   """Apply a function over leading dimension of tensors while carrying along state.

   This is similar to the JAX `jax.lax.scan` function found in [1].
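A note on the `XLA_DISABLE_FUNCTIONALIZATION` parsing change in `test/test_operations.py` above: `bool()` of any non-empty string is `True`, so the old expression reported functionalization as disabled even when the variable was explicitly set to `'0'`; converting through `int()` first yields the intended flag. The snippet below is only an illustrative sketch of that Python behavior and is not part of the patch.

import os

# Simulate a user explicitly setting the flag to the string '0' ("do not disable").
os.environ['XLA_DISABLE_FUNCTIONALIZATION'] = '0'

# Old parsing: any non-empty string, including '0', is truthy.
old = bool(os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', False))
# New parsing: convert through int() so '0' maps to False and '1' to True.
new = bool(int(os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', '0')))

assert old is True    # misreads '0' as "functionalization disabled"
assert new is False   # '0' correctly means "functionalization stays enabled"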