Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent 918d1fc commit d9e9921
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import torch

try:
from torch.compiler import is_dynamo_compiling
from torch.compiler import is_compiling
except ImportError:
from torch._dynamo import is_compiling as is_dynamo_compiling
from torch._dynamo import is_compiling

from tensordict import (
is_tensor_collection,
Expand Down Expand Up @@ -617,7 +617,7 @@ def _add(self, data):
return index

def _extend(self, data: Sequence) -> torch.Tensor:
is_compiling = is_dynamo_compiling()
is_compiling = is_compiling()
nc = contextlib.nullcontext()
with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
if self.dim_extend > 0:
Expand Down Expand Up @@ -672,7 +672,7 @@ def update_priority(

@pin_memory_output
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
with self._replay_lock if not is_dynamo_compiling() else contextlib.nullcontext():
with self._replay_lock if not is_compiling() else contextlib.nullcontext():
index, info = self._sampler.sample(self._storage, batch_size)
info["index"] = index
data = self._storage.get(index)
Expand Down Expand Up @@ -1343,7 +1343,7 @@ def sample(

@pin_memory_output
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
with self._replay_lock:
with self._replay_lock if not is_compiling() else contextlib.nullcontext():
index, info = self._sampler.sample(self._storage, batch_size)
info["index"] = index
data = self._storage.get(index)
Expand Down

0 comments on commit d9e9921

Please sign in to comment.