Skip to content

Commit

Permalink
pass property arg into to_tensor (PaddlePaddle#57502)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql authored and iosmers committed Sep 21, 2023
1 parent 6fbc3a4 commit b9057a5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
4 changes: 3 additions & 1 deletion python/paddle/distributed/auto_parallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def shard_tensor(
"""
# 1. create dense tensor
# `paddle.to_tensor` supports both dynamic and static mode
tensor = paddle.to_tensor(data)
tensor = paddle.to_tensor(
data, dtype=dtype, place=place, stop_gradient=stop_gradient
)

# 2. create dist tensor
assert len(dist_attr.dims_mapping) == len(
Expand Down
28 changes: 27 additions & 1 deletion test/auto_parallel/test_shard_tensor_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def setUp(self):
[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
)

def test_dynamic(self):
def test_dynamic_mode_basic(self):
dist_attr = dist.DistAttr(
mesh=self.mesh, sharding_specs=[None, None, None]
)
Expand All @@ -77,6 +77,32 @@ def test_dynamic(self):
self.assertTrue(d_tensor.dist_attr.is_annotated("process_mesh"))
self.assertTrue(d_tensor.dist_attr.is_annotated("dims_mapping"))

def test_dynamic_mode_property_change(self):
    """Shard a tensor while overriding dtype/place/stop_gradient.

    Builds a float32 CPU tensor with gradients enabled, then asks
    ``dist.shard_tensor`` to re-create it as float64 on gpu:0 with
    ``stop_gradient=True``, and checks both the overridden tensor
    properties and the (fully replicated) dist attributes.
    """
    # Source tensor deliberately uses the *opposite* of every
    # property we override below, so each override is observable.
    arr = np.random.random([4, 1024, 512]).astype("float32")
    src = paddle.to_tensor(
        arr, dtype="float32", place='cpu', stop_gradient=False
    )

    # No sharding on any axis: every dim stays replicated.
    attr = dist.DistAttr(
        mesh=self.mesh, sharding_specs=[None, None, None]
    )
    sharded = dist.shard_tensor(
        src,
        dtype="float64",
        place='gpu:0',
        stop_gradient=True,
        dist_attr=attr,
    )

    # The overridden properties must be reflected on the dist tensor.
    self.assertEqual(sharded.dtype, paddle.float64)
    self.assertTrue(sharded.place.is_gpu_place())
    self.assertEqual(sharded.stop_gradient, True)

    # Mesh is annotated and all three dims map to -1 (replicated).
    self.assertEqual(sharded.dist_attr.process_mesh, self.mesh)
    self.assertEqual(sharded.dist_attr.dims_mapping, [-1, -1, -1])
    self.assertTrue(sharded.dist_attr.is_annotated("process_mesh"))
    self.assertTrue(sharded.dist_attr.is_annotated("dims_mapping"))


class TestShardTensorStatic(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit b9057a5

Please sign in to comment.