Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents ebc44d5 + bd1af54 commit 5b50239
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
10 changes: 8 additions & 2 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from ppo_utils import eval_model, make_env, make_ppo_models
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import compile_with_warmup
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
Expand Down Expand Up @@ -84,7 +86,11 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create data buffer
data_buffer = TensorDictReplayBuffer(
storage=LazyTensorStorage(cfg.ppo.collector.frames_per_batch, device=device),
storage=LazyTensorStorage(
cfg.ppo.collector.frames_per_batch,
device=device,
compilable=cfg.compile.compile,
),
sampler=SamplerWithoutReplacement(),
batch_size=cfg.ppo.loss.mini_batch_size,
)
Expand Down Expand Up @@ -229,7 +235,7 @@ def update(data, expert_data, num_network_updates=num_network_updates):
return TensorDict(dloss=d_loss, alpha=alpha).detach()

if cfg.compile.compile:
update = torch.compile(update, mode=compile_mode)
update = compile_with_warmup(update, warmup=2, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
Expand Down
50 changes: 50 additions & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,3 +851,53 @@ def set_mode(self, type: Any | None) -> None:
cm = self._lock if not is_compiling() else nullcontext()
with cm:
self._mode = type


@wraps(torch.compile)
def compile_with_warmup(*args, warmup: int, **kwargs):
"""Compile a model with warm-up.
This function wraps :func:`~torch.compile` to add a warm-up phase. During the warm-up phase,
the original model is used. After the warm-up phase, the model is compiled using
`torch.compile`.
Args:
*args: Arguments to be passed to `torch.compile`.
warmup (int): Number of calls to the model before compiling it.
**kwargs: Keyword arguments to be passed to `torch.compile`.
Returns:
A callable that wraps the original model. If no model is provided, returns a
lambda function that takes a model as input and returns the wrapped model.
Notes:
If no model is provided, this function returns a lambda function that can be
used to wrap a model later. This allows for delayed compilation of the model.
Example:
>>> model = torch.nn.Linear(5, 3)
>>> compiled_model = compile_with_warmup(model, warmup=10)
>>> # First 10 calls use the original model
>>> # After 10 calls, the model is compiled and used
"""

if len(args):
model = args[0]
else:
model = kwargs.get("model")
if model is None:
return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs)
else:
count = 0
compiled_model = model

@wraps(model)
def count_and_compile(*model_args, **model_kwargs):
nonlocal count
nonlocal compiled_model
count += 1
if count == warmup:
compiled_model = torch.compile(model, *args, **kwargs)
return compiled_model(*model_args, **model_kwargs)

return count_and_compile

0 comments on commit 5b50239

Please sign in to comment.