4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import argparse
7
+ from typing import Any
7
8
8
9
import pytest
9
10
import torch
10
11
from packaging import version
11
12
12
- from tensordict import TensorDict
13
+ from tensordict import tensorclass , TensorDict
14
+ from tensordict .utils import logger as tensordict_logger
13
15
14
16
# Base torch version with any local/dev suffix stripped (e.g. "2.5.0a0+git..."
# -> "2.5.0"), so the version-gated skips below compare cleanly.
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
15
17
16
18
17
- @pytest .fixture
18
- def td ():
19
- return TensorDict (
20
- {
21
- str (i ): {str (j ): torch .randn (16 , 16 , device = "cpu" ) for j in range (16 )}
22
- for i in range (16 )
23
- },
24
- batch_size = [16 ],
25
- device = "cpu" ,
26
- )
19
@tensorclass
class NJT:
    """Tensorclass mirror of a torch nested-jagged tensor (NJT).

    Holds the raw jagged components so an NJT leaf can be stored inside a
    TensorDict as a regular tensorclass entry (see the ``td`` fixture below,
    which converts every leaf of an NJT TensorDict into this type).
    """

    # Raw components copied from torch's private NJT attributes of the same
    # names; field names intentionally mirror them.
    _values: torch.Tensor
    _offsets: torch.Tensor
    _lengths: torch.Tensor
    # size(0) of the source NJT — presumably the leading (batch) dimension,
    # kept so the jagged shape survives the round-trip; TODO confirm.
    njt_shape: Any = None

    @classmethod
    def from_njt(cls, njt_tensor):
        """Build an ``NJT`` tensorclass from a torch nested-jagged tensor.

        Reads the tensor's private ``_values``/``_offsets``/``_lengths``
        attributes directly and records ``njt_tensor.size(0)`` as
        ``njt_shape``.
        """
        return NJT(
            _values=njt_tensor._values,
            _offsets=njt_tensor._offsets,
            _lengths=njt_tensor._lengths,
            njt_shape=njt_tensor.size(0),
        )
34
+
35
+
36
@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
    """Autouse fixture: clear dynamo's per-code compilation caches before each
    test so one parametrization's compiled artifacts don't skew the next
    benchmark run."""
    torch._dynamo.reset_code_caches()
    yield
27
40
28
41
29
42
def _make_njt ():
@@ -34,14 +47,27 @@ def _make_njt():
34
47
)
35
48
36
49
37
- @pytest .fixture
38
- def njt_td ():
50
def _njt_td():
    """Build a two-level (32 x 32) TensorDict on CPU whose leaves are
    nested-jagged tensors produced by ``_make_njt``."""
    nested = {}
    for outer in range(32):
        nested[str(outer)] = {str(inner): _make_njt() for inner in range(32)}
    return TensorDict(nested, device="cpu")
43
55
44
56
57
@pytest.fixture
def njt_td():
    """Fresh TensorDict of nested-jagged-tensor leaves for each test."""
    data = _njt_td()
    return data
60
+
61
+
62
@pytest.fixture
def td():
    """Like ``njt_td`` but with every leaf converted to the ``NJT``
    tensorclass, so the TensorDict contains only plain-tensor components."""
    result = _njt_td()
    for outer_key, inner_td in result.items():
        for inner_key, leaf in inner_td.items():
            result[outer_key, inner_key] = NJT.from_njt(leaf)
    return result
69
+
70
+
45
71
@pytest .fixture
46
72
def default_device ():
47
73
if torch .cuda .is_available ():
@@ -52,22 +78,77 @@ def default_device():
52
78
pytest .skip ("CUDA/MPS is not available" )
53
79
54
80
55
@pytest.mark.parametrize(
    "consolidated,compile_mode,num_threads",
    [
        [False, False, None],
        [True, False, None],
        ["within", False, None],
        # [True, False, 4],
        # [True, False, 16],
        [True, "default", None],
    ],
)
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestTo:
    """Benchmarks of ``TensorDict.to(device)`` across consolidation strategies
    (none / pre-consolidated / consolidate-within-the-timed-call) and optional
    ``torch.compile`` wrapping."""

    def _benchmark_to(
        self, benchmark, consolidated, data, default_device, compile_mode, num_threads, label
    ):
        """Shared driver for the two tests below (they previously duplicated
        this logic verbatim).

        Builds the transfer callable according to ``consolidated``, optionally
        compiles it, warms it up (3 calls — also triggers compilation when
        enabled, keeping that cost out of the measurement), then hands it to
        pytest-benchmark.  ``label`` is only used in the size log line so the
        original messages are preserved exactly.
        """
        tensordict_logger.info(f"{label} size {data.bytes() / 1024 / 1024:.2f} Mb")
        if consolidated is True:
            # Consolidate once, outside the timed region.
            data = data.consolidate()
        # Pinned memory only helps (and is only valid) for CUDA transfers.
        pin_mem = default_device.type == "cuda"

        if consolidated == "within":
            # Consolidation happens inside the timed call.
            def to(td, num_threads):
                return td.consolidate(pin_memory=pin_mem, set_on_tensor=True).to(
                    default_device, num_threads=num_threads
                )

        else:

            def to(td, num_threads):
                return td.to(default_device, num_threads=num_threads)

        if compile_mode:
            to = torch.compile(to, mode=compile_mode)

        # Warm-up runs (cache allocators, compilation, etc.).
        for _ in range(3):
            to(data, num_threads=num_threads)

        benchmark(to, data, num_threads)

    def test_to(
        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
    ):
        """Device transfer of a TensorDict whose leaves are plain tensors
        (NJT tensorclass components)."""
        self._benchmark_to(
            benchmark, consolidated, td, default_device, compile_mode, num_threads, "td"
        )

    def test_to_njt(
        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
    ):
        """Device transfer of a TensorDict whose leaves are genuine
        nested-jagged tensors."""
        self._benchmark_to(
            benchmark, consolidated, njt_td, default_device, compile_mode, num_threads, "njtd"
        )
69
147
70
148
71
149
if __name__ == "__main__":
    # Allow running this benchmark file directly; unknown CLI args are
    # forwarded to pytest.
    parser = argparse.ArgumentParser()
    args, unknown = parser.parse_known_args()
    pytest_args = [
        __file__,
        "--capture",
        "no",
        "--exitfirst",
        "--benchmark-group-by",
        "func",
    ]
    pytest.main(pytest_args + unknown)
0 commit comments