
Commit 0611c42

Author: Vincent Moens

[Feature] A2C compatibility with compile

ghstack-source-id: 6f2f140
Pull Request resolved: #2464

1 parent d894358
File tree: 18 files changed, +491 -282 lines changed

benchmarks/test_objectives_benchmarks.py

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@
 )  # Anything from 2.5, incl. nightlies, allows for fullgraph


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="module", autouse=True)
 def set_default_device():
     cur_device = torch.get_default_device()
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
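
Note on this change: with autouse=True, the module-scoped fixture now applies to every benchmark in the file without being requested explicitly, so all benchmarks run with the default device pinned. A minimal sketch of the pattern; the set/restore body below is an assumption extrapolated from the visible context lines, not copied from the file:

import pytest
import torch


@pytest.fixture(scope="module", autouse=True)
def set_default_device():
    # Pin the default device for the whole module, then restore it afterwards.
    cur_device = torch.get_default_device()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.set_default_device(device)
    yield
    torch.set_default_device(cur_device)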

sota-implementations/a2c/a2c_atari.py

Lines changed: 99 additions & 59 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import hydra
+from tensordict.nn import CudaGraphModule
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder

@@ -15,17 +16,21 @@ def main(cfg: "DictConfig"):  # noqa: F821
     import torch.optim
     import tqdm

-    from tensordict import TensorDict
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import A2CLoss
     from torchrl.objectives.value.advantages import GAE
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_atari import eval_model, make_parallel_env, make_ppo_models

-    device = "cpu" if not torch.cuda.device_count() else "cuda"
+    device = cfg.loss.device
+    if not device:
+        device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
+    else:
+        device = torch.device(device)

     # Correct for frame_skip
     frame_skip = 4
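
The device is now taken from cfg.loss.device instead of being inferred from torch.cuda.device_count(), with an explicit fallback to cuda:0 when a GPU is available. A standalone sketch of the same resolution logic; the resolve_device helper is hypothetical, the script inlines this in main:

import torch


def resolve_device(cfg_device):
    # Prefer the config value; otherwise fall back to cuda:0 when a GPU is visible.
    if not cfg_device:
        return torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
    return torch.device(cfg_device)


print(resolve_device(""))        # cuda:0 on a GPU machine, cpu otherwise
print(resolve_device("cuda:1"))  # an explicit config entry is used as-is
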
@@ -35,28 +40,12 @@ def main(cfg: "DictConfig"):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip

     # Create models (check utils_atari.py)
-    actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
-    actor, critic, critic_head = (
-        actor.to(device),
-        critic.to(device),
-        critic_head.to(device),
-    )
-
-    # Create collector
-    collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
-        policy=actor,
-        frames_per_batch=frames_per_batch,
-        total_frames=total_frames,
-        device=device,
-        storing_device=device,
-        max_frames_per_traj=-1,
-    )
+    actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)

     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(frames_per_batch),
+        storage=LazyTensorStorage(frames_per_batch, device=device),
         sampler=sampler,
         batch_size=mini_batch_size,
     )
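
Swapping LazyMemmapStorage for LazyTensorStorage with a device keeps the on-policy batch in GPU memory, so minibatches drawn for the update no longer need a host-to-device copy (which is also why the old `batch = batch.to(device)` line disappears in the training loop below). A minimal sketch of the buffer setup, assuming a CUDA machine and placeholder sizes:

import torch
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

frames_per_batch = 4096   # placeholder, the script derives this from the config
mini_batch_size = 1024    # placeholder

# The storage lives on the GPU: extend() and the sampled minibatches stay on-device.
data_buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(frames_per_batch, device=torch.device("cuda:0")),
    sampler=SamplerWithoutReplacement(),
    batch_size=mini_batch_size,
)
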
@@ -67,6 +56,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         lmbda=cfg.loss.gae_lambda,
         value_network=critic,
         average_gae=True,
+        vectorized=not cfg.loss.compile,
     )
     loss_module = A2CLoss(
         actor_network=actor,
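
The GAE module is now built with vectorized=not cfg.loss.compile: the vectorized estimator is the faster path in eager mode, while the compiled run presumably does better with the plain time-loop variant that torch.compile can trace and fuse. A sketch of the two configurations, with placeholder hyperparameters and a toy value network standing in for the A2C critic:

import torch
from tensordict.nn import TensorDictModule
from torchrl.objectives.value.advantages import GAE

# Toy critic standing in for the A2C value network.
critic = TensorDictModule(
    torch.nn.LazyLinear(1), in_keys=["observation"], out_keys=["state_value"]
)

# Eager run: keep the vectorized advantage computation.
adv_eager = GAE(gamma=0.99, lmbda=0.95, value_network=critic, average_gae=True, vectorized=True)

# Compiled run: use the loop implementation and let torch.compile optimize it.
adv_compiled = GAE(gamma=0.99, lmbda=0.95, value_network=critic, average_gae=True, vectorized=False)
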
@@ -83,9 +73,10 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create optimizer
     optim = torch.optim.Adam(
         loss_module.parameters(),
-        lr=cfg.optim.lr,
+        lr=torch.tensor(cfg.optim.lr, device=device),
         weight_decay=cfg.optim.weight_decay,
         eps=cfg.optim.eps,
+        capturable=device.type == "cuda",
     )

     # Create logger
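
Two details of the optimizer change matter for compile and CUDA graph support: the learning rate is passed as a tensor on the training device so it can later be annealed in place (the group["lr"].copy_(...) call in the loop below) without rebuilding the optimizer or breaking a captured graph, and capturable=True lets the Adam step itself run inside a CUDA graph. A minimal sketch of the pattern, assuming a CUDA device and a placeholder model:

import torch

device = torch.device("cuda:0")
model = torch.nn.Linear(8, 1, device=device)
base_lr = 3e-4

optim = torch.optim.Adam(
    model.parameters(),
    lr=torch.tensor(base_lr, device=device),  # tensor lr: mutable in place
    capturable=True,                          # allow graph-captured optimizer steps
)

# Linear annealing later on, without touching the optimizer structure:
alpha = 0.5
for group in optim.param_groups:
    group["lr"].copy_(base_lr * alpha)
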
@@ -115,16 +106,72 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
     test_env.eval()

+    # update function
+    def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
+        # Forward pass A2C loss
+        loss = loss_module(batch)
+
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+
+        # Backward pass
+        loss_sum.backward()
+        gn = torch.nn.utils.clip_grad_norm_(
+            loss_module.parameters(), max_norm=max_grad_norm
+        )
+
+        # Update the networks
+        optim.step()
+        optim.zero_grad(set_to_none=True)
+
+        return (
+            loss.select("loss_critic", "loss_entropy", "loss_objective")
+            .detach()
+            .set("grad_norm", gn)
+        )
+
+    compile_mode = None
+    if cfg.loss.compile:
+        compile_mode = cfg.loss.compile_mode
+        if compile_mode in ("", None):
+            if cfg.loss.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+        adv_module = torch.compile(adv_module, mode=compile_mode)
+
+    if cfg.loss.cudagraphs:
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+        adv_module = CudaGraphModule(adv_module)
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
+        policy=actor,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        device=device,
+        storing_device=device,
+        policy_device=device,
+        compile_policy={"mode": compile_mode} if cfg.loss.compile else False,
+        cudagraph_policy=cfg.loss.cudagraphs,
+    )
+
     # Main loop
     collected_frames = 0
     num_network_updates = 0
     start_time = time.time()
     pbar = tqdm.tqdm(total=total_frames)
     num_mini_batches = frames_per_batch // mini_batch_size
     total_network_updates = (total_frames // frames_per_batch) * num_mini_batches
+    lr = cfg.optim.lr

     sampling_start = time.time()
-    for i, data in enumerate(collector):
+    c_iter = iter(collector)
+    for i in range(len(collector)):
+        with timeit("collecting"):
+            torch.compiler.cudagraph_mark_step_begin()
+            data = next(c_iter)

         log_info = {}
         sampling_time = time.time() - sampling_start
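
This hunk carries the core of the feature: the whole gradient step is factored into a single update(batch) callable so that it can be wrapped as one unit by torch.compile and, optionally, by CudaGraphModule, and the same compile/cudagraph options are forwarded to the collector via compile_policy and cudagraph_policy. Note that when cudagraphs are requested the compile mode falls back to "default" rather than "reduce-overhead", presumably to avoid stacking torch.compile's own graph capture on top of CudaGraphModule. A reduced sketch of the wrapping order; the make_update_fn helper is hypothetical, and the loss module, optimizer, and flags are placeholders:

import torch
from tensordict.nn import CudaGraphModule


def make_update_fn(loss_module, optim, max_grad_norm, compile_mode=None, cudagraphs=False):
    def update(batch):
        # Forward, backward, clip, and optimizer step in one callable.
        loss = loss_module(batch)
        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
        loss_sum.backward()
        gn = torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=max_grad_norm)
        optim.step()
        optim.zero_grad(set_to_none=True)
        return loss.select("loss_critic", "loss_entropy", "loss_objective").detach().set("grad_norm", gn)

    if compile_mode is not None:
        # Compile first so the CUDA graph wraps the compiled callable.
        update = torch.compile(update, mode=compile_mode)
    if cudagraphs:
        # A few warmup iterations run eagerly before the graph is captured.
        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
    return update
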
@@ -144,61 +191,55 @@ def main(cfg: "DictConfig"):  # noqa: F821
             }
         )

-        losses = TensorDict({}, batch_size=[num_mini_batches])
+        losses = []
         training_start = time.time()

         # Compute GAE
-        with torch.no_grad():
+        with torch.no_grad(), timeit("advantage"):
             data = adv_module(data)
         data_reshape = data.reshape(-1)

         # Update the data buffer
-        data_buffer.extend(data_reshape)
-
-        for k, batch in enumerate(data_buffer):
-
-            # Get a data batch
-            batch = batch.to(device)
-
-            # Linearly decrease the learning rate and clip epsilon
-            alpha = 1.0
-            if cfg.optim.anneal_lr:
-                alpha = 1 - (num_network_updates / total_network_updates)
-                for group in optim.param_groups:
-                    group["lr"] = cfg.optim.lr * alpha
-            num_network_updates += 1
-
-            # Forward pass A2C loss
-            loss = loss_module(batch)
-            losses[k] = loss.select(
-                "loss_critic", "loss_entropy", "loss_objective"
-            ).detach()
-            loss_sum = (
-                loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-            )
+        with timeit("emptying"):
+            data_buffer.empty()
+        with timeit("extending"):
+            data_buffer.extend(data_reshape)

-            # Backward pass
-            loss_sum.backward()
-            torch.nn.utils.clip_grad_norm_(
-                list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
-            )
+        with timeit("optim"):
+            for batch in data_buffer:
+
+                # Linearly decrease the learning rate and clip epsilon
+                with timeit("optim - lr"):
+                    alpha = 1.0
+                    if cfg.optim.anneal_lr:
+                        alpha = 1 - (num_network_updates / total_network_updates)
+                        for group in optim.param_groups:
+                            group["lr"].copy_(lr * alpha)

-            # Update the networks
-            optim.step()
-            optim.zero_grad()
+                num_network_updates += 1
+
+                with timeit("optim - update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    loss = update(batch)
+                losses.append(loss)

         # Get training losses
         training_time = time.time() - training_start
-        losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
+        losses = torch.stack(losses).float().mean()
+
         for key, value in losses.items():
             log_info.update({f"train/{key}": value.item()})
         log_info.update(
             {
-                "train/lr": alpha * cfg.optim.lr,
+                "train/lr": lr * alpha,
                 "train/sampling_time": sampling_time,
                 "train/training_time": training_time,
+                **timeit.todict(prefix="time"),
             }
         )
+        if i % 200 == 0:
+            timeit.print()
+            timeit.erase()

         # Get test rewards
         with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
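
The training loop is also instrumented with torchrl's timeit helper: each phase runs under a named context, the accumulated timings are merged into the logs through timeit.todict(prefix="time"), and every 200 iterations the counters are printed and reset. A small standalone sketch of that usage, with sleep calls standing in for real work and arbitrary phase names:

import time
from torchrl._utils import timeit

for i in range(3):
    with timeit("collecting"):
        time.sleep(0.01)  # stand-in for data collection
    with timeit("optim"):
        time.sleep(0.02)  # stand-in for the update

    logs = {**timeit.todict(prefix="time")}  # e.g. {"time/collecting": ..., "time/optim": ...}
    if i % 200 == 0:
        timeit.print()
        timeit.erase()
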
@@ -223,7 +264,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
         for key, value in log_info.items():
             logger.log_scalar(key, value, collected_frames)

-        collector.update_policy_weights_()
         sampling_start = time.time()

     collector.shutdown()
