From 5259e6eb934849be93a85e98f1a4eda5c8dd86eb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 19:12:38 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- sota-implementations/gail/gail.py | 10 +++++-- torchrl/_utils.py | 50 +++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 969c7fc083e..48e3b18c2c3 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -22,6 +22,8 @@ from ppo_utils import eval_model, make_env, make_ppo_models from tensordict import TensorDict from tensordict.nn import CudaGraphModule + +from torchrl._utils import compile_with_warmup from torchrl.collectors import SyncDataCollector from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -84,7 +86,11 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create data buffer data_buffer = TensorDictReplayBuffer( - storage=LazyTensorStorage(cfg.ppo.collector.frames_per_batch, device=device), + storage=LazyTensorStorage( + cfg.ppo.collector.frames_per_batch, + device=device, + compilable=cfg.compile.compile, + ), sampler=SamplerWithoutReplacement(), batch_size=cfg.ppo.loss.mini_batch_size, ) @@ -229,7 +235,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): return TensorDict(dloss=d_loss, alpha=alpha).detach() if cfg.compile.compile: - update = torch.compile(update, mode=compile_mode) + update = compile_with_warmup(update, warmup=2, mode=compile_mode) if cfg.compile.cudagraphs: warnings.warn( "CudaGraphModule is experimental and may lead to silently wrong results. 
@wraps(torch.compile)
def compile_with_warmup(*args, warmup: int, **kwargs):
    """Compile a model with a warm-up phase.

    Wraps :func:`torch.compile`: for the first ``warmup - 1`` calls the
    original (eager) model is executed; on call number ``warmup`` the model
    is compiled via ``torch.compile`` and the compiled version is used for
    that call and every call after it.

    Args:
        *args: positional arguments forwarded to :func:`torch.compile`
            (the first one, if present, is the model to wrap).
        warmup (int): number of calls before the compiled model takes over.
        **kwargs: keyword arguments forwarded to :func:`torch.compile`.

    Returns:
        A callable wrapping the model. If no model is provided, returns a
        decorator that takes the model as input and applies the same
        wrapping — enabling delayed compilation, e.g.
        ``@compile_with_warmup(warmup=10)``.

    Example:
        >>> model = torch.nn.Linear(5, 3)
        >>> compiled_model = compile_with_warmup(model, warmup=10)
        >>> # calls 1-9 run the eager model; from call 10 on, the compiled one
    """
    if len(args):
        model = args[0]
    else:
        # Pop (not get) so "model" is never forwarded twice to torch.compile.
        model = kwargs.pop("model", None)
    if model is None:
        # Decorator form: defer wrapping until the model is supplied.
        return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs)
    else:
        count = 0
        compiled_model = model

        @wraps(model)
        def count_and_compile(*model_args, **model_kwargs):
            nonlocal count
            nonlocal compiled_model
            count += 1
            if count == warmup:
                # Forward only the *extra* positional args. The original code
                # passed ``*args``, which re-included the model, so
                # torch.compile received the model twice positionally and
                # raised TypeError the moment warm-up completed.
                compiled_model = torch.compile(model, *args[1:], **kwargs)
            return compiled_model(*model_args, **model_kwargs)

        return count_and_compile