From 81539eccb19a46fa84b2a206a99442a93563e974 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 5 Feb 2024 20:37:47 +0000 Subject: [PATCH] amend --- torchrl/_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 4ce56fec0b2..2f562451adb 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -711,10 +711,15 @@ class _rng_decorator(_DecoratorContextManager): def __init__(self, seed): self.seed = seed + self.stream = torch.cuda.Stream() + self.event = self.stream.record_event() def __enter__(self): self._state = torch.random.get_rng_state() torch.manual_seed(self.seed) + return torch.cuda.stream(self.stream) def __exit__(self, exc_type, exc_val, exc_tb): torch.random.set_rng_state(self._state) + self.event.wait() + self.event.synchronize()