@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import hydra
+from tensordict.nn import CudaGraphModule
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
 
@@ -15,17 +16,21 @@ def main(cfg: "DictConfig"): # noqa: F821
     import torch.optim
     import tqdm
 
-    from tensordict import TensorDict
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import A2CLoss
     from torchrl.objectives.value.advantages import GAE
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_atari import eval_model, make_parallel_env, make_ppo_models
 
-    device = "cpu" if not torch.cuda.device_count() else "cuda"
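+    # resolve the training device: honor cfg.loss.device when set, otherwise
+    # fall back to the first GPU if one is available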
+    device = cfg.loss.device
+    if not device:
+        device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
+    else:
+        device = torch.device(device)
 
     # Correct for frame_skip
     frame_skip = 4
@@ -35,28 +40,12 @@ def main(cfg: "DictConfig"): # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip
 
     # Create models (check utils_atari.py)
-    actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
-    actor, critic, critic_head = (
-        actor.to(device),
-        critic.to(device),
-        critic_head.to(device),
-    )
-
-    # Create collector
-    collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
-        policy=actor,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
-    )
+    actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)
 
     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(frames_per_batch),
+        storage=LazyTensorStorage(frames_per_batch, device=device),
         sampler=sampler,
         batch_size=mini_batch_size,
     )
@@ -67,6 +56,7 @@ def main(cfg: "DictConfig"): # noqa: F821
         lmbda=cfg.loss.gae_lambda,
         value_network=critic,
         average_gae=True,
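+        # skip the vectorized GAE scan when compiling; torch.compile is left to
+        # optimize the plain loop instead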
+        vectorized=not cfg.loss.compile,
     )
     loss_module = A2CLoss(
         actor_network=actor,
@@ -83,9 +73,10 @@ def main(cfg: "DictConfig"): # noqa: F821
     # Create optimizer
     optim = torch.optim.Adam(
         loss_module.parameters(),
-        lr=cfg.optim.lr,
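+        # a tensor lr, together with capturable=True below, keeps Adam's state
+        # on-device so the step can be annealed in-place and CUDA-graph captured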
+        lr=torch.tensor(cfg.optim.lr, device=device),
         weight_decay=cfg.optim.weight_decay,
         eps=cfg.optim.eps,
+        capturable=device.type == "cuda",
     )
 
     # Create logger
@@ -115,16 +106,72 @@ def main(cfg: "DictConfig"): # noqa: F821
     )
     test_env.eval()
 
+    # update function
+    def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
+        # Forward pass A2C loss
+        loss = loss_module(batch)
+
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+
+        # Backward pass
+        loss_sum.backward()
+        gn = torch.nn.utils.clip_grad_norm_(
+            loss_module.parameters(), max_norm=max_grad_norm
+        )
+
+        # Update the networks
+        optim.step()
+        optim.zero_grad(set_to_none=True)
+
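+        # return only detached losses (plus the grad norm) so logging does not
+        # keep the autograd graph alive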
+        return (
+            loss.select("loss_critic", "loss_entropy", "loss_objective")
+            .detach()
+            .set("grad_norm", gn)
+        )
+
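+    # pick a torch.compile mode: "reduce-overhead" captures CUDA graphs on its
+    # own, so use "default" when CudaGraphModule will handle the capture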
+    compile_mode = None
+    if cfg.loss.compile:
+        compile_mode = cfg.loss.compile_mode
+        if compile_mode in ("", None):
+            if cfg.loss.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+        adv_module = torch.compile(adv_module, mode=compile_mode)
+
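+    # warmup runs a few eager iterations before capture so lazy state (e.g. the
+    # optimizer's) is initialized outside the graph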
+    if cfg.loss.cudagraphs:
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+        adv_module = CudaGraphModule(adv_module)
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
+        policy=actor,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        policy_device=device,
+        compile_policy={"mode": compile_mode} if cfg.loss.compile else False,
+        cudagraph_policy=cfg.loss.cudagraphs,
+    )
+
     # Main loop
     collected_frames = 0
     num_network_updates = 0
     start_time = time.time()
     pbar = tqdm.tqdm(total=total_frames)
     num_mini_batches = frames_per_batch // mini_batch_size
     total_network_updates = (total_frames // frames_per_batch) * num_mini_batches
+    lr = cfg.optim.lr
 
     sampling_start = time.time()
-    for i, data in enumerate(collector):
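+    # iterate by hand so each collection can be timed and flagged as the start
+    # of a new CUDA-graph step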
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            torch.compiler.cudagraph_mark_step_begin()
+            data = next(c_iter)
 
         log_info = {}
         sampling_time = time.time() - sampling_start
@@ -144,61 +191,55 @@ def main(cfg: "DictConfig"): # noqa: F821
             }
         )
 
-        losses = TensorDict({}, batch_size=[num_mini_batches])
+        losses = []
         training_start = time.time()
 
         # Compute GAE
-        with torch.no_grad():
+        with torch.no_grad(), timeit("advantage"):
             data = adv_module(data)
         data_reshape = data.reshape(-1)
 
         # Update the data buffer
-        data_buffer.extend(data_reshape)
-
-        for k, batch in enumerate(data_buffer):
-
-            # Get a data batch
-            batch = batch.to(device)
-
-            # Linearly decrease the learning rate and clip epsilon
-            alpha = 1.0
-            if cfg.optim.anneal_lr:
-                alpha = 1 - (num_network_updates / total_network_updates)
-                for group in optim.param_groups:
-                    group["lr"] = cfg.optim.lr * alpha
-            num_network_updates += 1
-
-            # Forward pass A2C loss
-            loss = loss_module(batch)
-            losses[k] = loss.select(
-                "loss_critic", "loss_entropy", "loss_objective"
-            ).detach()
-            loss_sum = (
-                loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-            )
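+        # clear the buffer before refilling so the sampler without replacement
+        # only ever draws from the freshly collected batch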
+        with timeit("emptying"):
+            data_buffer.empty()
+        with timeit("extending"):
+            data_buffer.extend(data_reshape)
 
-            # Backward pass
-            loss_sum.backward()
-            torch.nn.utils.clip_grad_norm_(
-                list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
-            )
+        with timeit("optim"):
+            for batch in data_buffer:
+
+                # Linearly decrease the learning rate and clip epsilon
+                with timeit("optim - lr"):
+                    alpha = 1.0
+                    if cfg.optim.anneal_lr:
+                        alpha = 1 - (num_network_updates / total_network_updates)
+                        for group in optim.param_groups:
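+                            # copy in-place: the lr tensor's storage must stay
+                            # put for the captured optimizer step to see it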
+                            group["lr"].copy_(lr * alpha)
 
-            # Update the networks
-            optim.step()
-            optim.zero_grad()
+                num_network_updates += 1
+
+                with timeit("optim - update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss = update(batch)
+                losses.append(loss)
 
         # Get training losses
         training_time = time.time() - training_start
-        losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
+        losses = torch.stack(losses).float().mean()
+
         for key, value in losses.items():
             log_info.update({f"train/{key}": value.item()})
         log_info.update(
             {
-                "train/lr": alpha * cfg.optim.lr,
+                "train/lr": lr * alpha,
                 "train/sampling_time": sampling_time,
                 "train/training_time": training_time,
+                **timeit.todict(prefix="time"),
             }
         )
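+        # periodically print and reset the accumulated timings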
+        if i % 200 == 0:
+            timeit.print()
+            timeit.erase()
 
         # Get test rewards
         with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
@@ -223,7 +264,6 @@ def main(cfg: "DictConfig"): # noqa: F821
             for key, value in log_info.items():
                 logger.log_scalar(key, value, collected_frames)
 
-        collector.update_policy_weights_()
         sampling_start = time.time()
 
     collector.shutdown()