From 1518b4a608c6d11853d237cfb93a3ea3cbd2a495 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 19:17:50 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/data/tensor_specs.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 0f6acae9803..230f396b29b 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1520,7 +1520,9 @@ def __init__( use_register: bool = False, mask: torch.Tensor | None = None, ): - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) self.use_register = use_register space = CategoricalBox(n) if shape is None: @@ -2046,7 +2048,9 @@ def __init__( if len(kwargs): raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.") - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) if dtype is None: dtype = torch.get_default_dtype() if domain is None: @@ -2644,7 +2648,9 @@ def __init__( if isinstance(shape, int): shape = _size([shape]) - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) if dtype == torch.bool: min_value = False max_value = True @@ -2851,7 +2857,9 @@ def __init__( mask: torch.Tensor | None = None, ): self.nvec = nvec - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) if shape is None: shape = _size((sum(nvec),)) else: @@ -3327,7 +3335,9 @@ def __init__( ): if shape is None: shape = _size([]) - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) space = CategoricalBox(n) super().__init__( shape=shape, space=space, device=device, dtype=dtype, domain="discrete" @@ -3874,7 +3884,9 @@ def __init__( if nvec.ndim < 1: nvec = nvec.unsqueeze(0) self.nvec = nvec - dtype, device = _default_dtype_and_device(dtype, device) + dtype, device = _default_dtype_and_device( + dtype, device, allow_none_device=False + ) if shape is None: shape = nvec.shape else: