
Commit 00177c0

Author: Vincent Moens
[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: 36af01f
Pull Request resolved: #1041
1 parent 75b33c4 commit 00177c0
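
For context: _to_consolidated is the code path TensorDict.to() takes once a tensordict has been consolidated into a single contiguous storage, and this commit makes that path usable under torch.compile. A minimal sketch of the pattern the benchmarks below exercise (assuming a CUDA device; the names and sizes are illustrative, not from the diff):

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {str(i): torch.randn(16, 16) for i in range(16)},
        batch_size=[16],
    )

    # consolidate() packs every leaf into one contiguous storage, so the
    # transfer below is a single large copy rather than one copy per leaf.
    def transfer(td):
        return td.consolidate().to("cuda")

    # The point of the change: this transfer can now be compiled.
    transfer_c = torch.compile(transfer)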

File tree

5 files changed: +392 −96 lines changed


benchmarks/common/h2d_test.py

Lines changed: 102 additions & 21 deletions
@@ -4,26 +4,39 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+from typing import Any
 
 import pytest
 import torch
 from packaging import version
 
-from tensordict import TensorDict
+from tensordict import tensorclass, TensorDict
+from tensordict.utils import logger as tensordict_logger
 
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 
-@pytest.fixture
-def td():
-    return TensorDict(
-        {
-            str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
-            for i in range(16)
-        },
-        batch_size=[16],
-        device="cpu",
-    )
+@tensorclass
+class NJT:
+    _values: torch.Tensor
+    _offsets: torch.Tensor
+    _lengths: torch.Tensor
+    njt_shape: Any = None
+
+    @classmethod
+    def from_njt(cls, njt_tensor):
+        return NJT(
+            _values=njt_tensor._values,
+            _offsets=njt_tensor._offsets,
+            _lengths=njt_tensor._lengths,
+            njt_shape=njt_tensor.size(0),
+        )
+
+
+@pytest.fixture(autouse=True, scope="function")
+def empty_compiler_cache():
+    torch._dynamo.reset_code_caches()
+    yield
 
 
 def _make_njt():
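
Aside: the NJT tensorclass introduced above mirrors the internals of a jagged-layout nested tensor. A sketch of how such a tensor can be built and wrapped (shapes illustrative; _lengths may be None for contiguous NJTs):

    import torch

    # A jagged ("NJT") nested tensor: rows of differing lengths.
    njt = torch.nested.nested_tensor(
        [torch.randn(3, 8), torch.randn(5, 8)],
        layout=torch.jagged,
    )

    # Wrapping exposes the flat components to TensorDict machinery.
    wrapped = NJT.from_njt(njt)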
@@ -34,14 +47,27 @@ def _make_njt():
     )
 
 
-@pytest.fixture
-def njt_td():
+def _njt_td():
     return TensorDict(
         {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
         device="cpu",
     )
 
 
+@pytest.fixture
+def njt_td():
+    return _njt_td()
+
+
+@pytest.fixture
+def td():
+    njtd = _njt_td()
+    for k0, v0 in njtd.items():
+        for k1, v1 in v0.items():
+            njtd[k0, k1] = NJT.from_njt(v1)
+    return njtd
+
+
 @pytest.fixture
 def default_device():
     if torch.cuda.is_available():
@@ -52,22 +78,77 @@ def default_device():
         pytest.skip("CUDA/MPS is not available")
 
 
-@pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.parametrize(
+    "consolidated,compile_mode,num_threads",
+    [
+        [False, False, None],
+        [True, False, None],
+        ["within", False, None],
+        # [True, False, 4],
+        # [True, False, 16],
+        [True, "default", None],
+    ],
+)
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
 )
 class TestTo:
-    def test_to(self, benchmark, consolidated, td, default_device):
-        if consolidated:
+    def test_to(
+        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
+    ):
+        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
+        if consolidated is True:
             td = td.consolidate()
-        benchmark(lambda: td.to(default_device))
+        pin_mem = default_device.type == "cuda"
+
+        if consolidated == "within":
+
+            def to(td, num_threads):
+                return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(default_device, num_threads=num_threads)
+
+        else:
+
+            def to(td, num_threads):
+                return td.to(default_device, num_threads=num_threads)
+
+        if compile_mode:
+            to = torch.compile(to, mode=compile_mode)
+
+        for _ in range(3):
+            to(td, num_threads=num_threads)
+
+        benchmark(to, td, num_threads)
 
-    def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
-        if consolidated:
+    def test_to_njt(
+        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
+    ):
+        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
+        if consolidated is True:
             njt_td = njt_td.consolidate()
-        benchmark(lambda: njt_td.to(default_device))
+        pin_mem = default_device.type == "cuda"
+
+        if consolidated == "within":
+
+            def to(td, num_threads):
+                return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(default_device, num_threads=num_threads)
+
+        else:
+
+            def to(td, num_threads):
+                return td.to(default_device, num_threads=num_threads)
+
+        if compile_mode:
+            to = torch.compile(to, mode=compile_mode)
+
+        for _ in range(3):
+            to(njt_td, num_threads=num_threads)
+
+        benchmark(to, njt_td, num_threads)
 
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
-    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
+    pytest.main(
+        [__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"]
+        + unknown
+    )
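
Note the shape both tests share: the transfer closure is optionally compiled, then called three times before being handed to pytest-benchmark, so compilation and warm-up cost stay outside the timed region. Stand-alone, the same idea looks like this (a sketch; fn and args are placeholders):

    def measure(benchmark, fn, *args):
        # Warm-up: let torch.compile trace and compile outside the timing.
        for _ in range(3):
            fn(*args)
        # pytest-benchmark then times fn(*args) over many rounds.
        benchmark(fn, *args)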

benchmarks/compile/compile_td_test.py

Lines changed: 6 additions & 0 deletions
@@ -23,6 +23,12 @@ class MyTensorClass:
     f: torch.Tensor
 
 
+@pytest.fixture(autouse=True, scope="function")
+def empty_compiler_cache():
+    torch._dynamo.reset_code_caches()
+    yield
+
+
 # Functions
 def add_one(td):
     return td + 1
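
This adds the same autouse fixture as in h2d_test.py: resetting Dynamo's code caches between tests keeps one benchmark's compiled artifacts from turning the next into a cache hit. Roughly (an illustrative sketch, not part of the diff):

    import torch

    compiled = torch.compile(lambda x: x * 2)
    compiled(torch.ones(4))            # first call pays the compilation cost
    torch._dynamo.reset_code_caches()  # drop the cached compiled code
    compiled(torch.ones(4))            # should recompile instead of reusing the cache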
