Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/release/0.3.2' into release/0.3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 7, 2024
2 parents 9301cb6 + 6a18431 commit 3ed6e8b
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions torchrl/data/replay_buffers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import torch

from torch import Tensor

INT_CLASSES_TYPING = Union[int, np.integer]
Expand Down Expand Up @@ -87,3 +88,11 @@ def _reduce(
if isinstance(result, tuple):
result = result[0]
return result.item() if dim is None else result


def _is_int(index):
if isinstance(index, INT_CLASSES):
return True
if isinstance(index, (np.ndarray, torch.Tensor)):
return index.ndim == 0
return False

0 comments on commit 3ed6e8b

Please sign in to comment.