Skip to content

Commit

Permalink
Implement clone() (#429)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tfogal and pre-commit-ci[bot] authored May 25, 2024
1 parent 4c9a765 commit 7d6e540
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
49 changes: 49 additions & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2313,6 +2313,55 @@ def forward(self, x):
assert "t_fc2_weight" in sig.parameters


@requiresCUDA
def test_clone():
def foo(a):
return a.clone()

jfoo = thunder.jit(foo)
for shp in ((3, 5), [7], (8, 6, 4)):
for dev in (torch.device("cpu"), torch.device("cuda:0")):
for dt in (torch.float32, torch.float16, torch.bfloat16):
# there are issues with layouts other than strided; see
# test_clone_sparse_coo.
lout = torch.strided
b = jfoo(torch.randn(shp, device=dev, layout=lout, dtype=dt))
assert b.dtype == dt
assert b.layout == lout
assert b.device == dev
assert b.shape == torch.Size(shp)


# Separate out the sparse test because creating a sparse tensor is tricky.
def test_clone_sparse_coo():
def foo(a):
return a.clone()

jfoo = thunder.jit(foo)
shp = (3, 5)
dev = torch.device("cpu")
dt = torch.float32
# randn(layout=torch.sparse_coo, ...) will throw an exception deep in
# PyTorch, so we use to_sparse() from a dense tensor to get a sparse one.
b = jfoo(torch.randn(shp, device=dev, dtype=dt).to_sparse())
assert b.dtype == dt
assert b.layout == torch.sparse_coo
assert b.device == dev
assert b.shape == torch.Size(shp)


@pytest.mark.xfail(reason="we improperly use an alias")
def test_clone_alias():
def foo(a):
b = a.clone()
b[0] = 42

jfoo = thunder.jit(foo)
arg = torch.tensor([7, 19])
jfoo(arg)
assert arg[0] == 7


@instantiate(dtypes=(thunder.float32,))
def test_default_method(executor, device: str, dtype: dtypes.dtype):
# This test ensures that when no language context is given, it will fallback to the default implementation.
Expand Down
31 changes: 31 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,6 +2001,37 @@ def amin(a, /, dim=None, keepdim: bool = False):
)


# Clone is unique in that it's not registered as a symbol; as such we add it to
# the appropriate maps manually, instead of through the @torchsymbol decorator.
# This means that clone will not appear in the trace; instead, this basically
# just gets inlined into the code.
def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy:
"""
Produce a copy of a tensor as a distinct new tensor.
Note: the implementation currently creates an alias instead of a copy.
"""
# Our implementation currently does not introduce a copy, and so nothing
# except preserve_format is feasible to support.
# If you're hitting this you could try commenting this check out; if your
# model does not actually rely on specified memory formats then it should
# be fine.
if memory_format is not torch.preserve_format:
raise NotImplementedError("only preserve_format is currently supported")
# This implementation just creates an alias instead of a copy. This may
# introduce problems; such problems would be fixable when we get to adding an
# SSA pass, but for now we do not expect that aliasing the tensor will
# introduce many problems.
return a


# Because we do not use @torchsymbol, we need to manually register the
# implementation.
_torch_to_thunder_function_map[torch.clone] = clone
_torch_to_thunder_function_map[torch.Tensor.clone] = clone
register_method("clone", clone)


@torchsymbol(torch.mean, is_method=True)
def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None):
dtype = dtype if dtype is not None else a.dtype
Expand Down

0 comments on commit 7d6e540

Please sign in to comment.