diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 4ce56fec0b2..2f562451adb 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -711,10 +711,15 @@ class _rng_decorator(_DecoratorContextManager): def __init__(self, seed): self.seed = seed + self.stream = torch.cuda.Stream() + self.event = self.stream.record_event() def __enter__(self): self._state = torch.random.get_rng_state() torch.manual_seed(self.seed) + return torch.cuda.stream(self.stream) def __exit__(self, exc_type, exc_val, exc_tb): torch.random.set_rng_state(self._state) + self.event.wait() + self.event.synchronize()