From 9b6eb79505be5a17baec6a019a5a17fc3a743c7a Mon Sep 17 00:00:00 2001
From: Ievgen Khvedchenia
Date: Tue, 13 Feb 2024 11:48:06 +0200
Subject: [PATCH] move_to_device

---
 pytorch_toolbelt/utils/torch_utils.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/pytorch_toolbelt/utils/torch_utils.py b/pytorch_toolbelt/utils/torch_utils.py
index 3fdc10b94..f45f0c31b 100644
--- a/pytorch_toolbelt/utils/torch_utils.py
+++ b/pytorch_toolbelt/utils/torch_utils.py
@@ -6,7 +6,7 @@
 import functools
 import warnings
 import logging
-from typing import Optional, Sequence, Union, Dict, List, Any, Iterable, Callable
+from typing import Optional, Sequence, Union, Dict, List, Any, Iterable, Callable, Tuple
 
 import numpy as np
 import torch
@@ -40,6 +40,7 @@
     "to_numpy",
     "to_tensor",
     "transfer_weights",
+    "move_to_device",
     "move_to_device_non_blocking",
     "describe_outputs",
     "get_collate_for_dataset",
@@ -328,8 +329,20 @@ def resize_like(x: Tensor, target: Tensor, mode: str = "bilinear", align_corners
 
 
 def move_to_device_non_blocking(x: Tensor, device: torch.device) -> Tensor:
-    if x.device != device:
-        x = x.to(device=device, non_blocking=True)
+    return move_to_device(x, device, non_blocking=True)
+
+
+def move_to_device(
+    x: Union[Tensor, List[Tensor], Tuple[Tensor, ...], Dict[Any, Tensor]], device: torch.device, non_blocking=False
+) -> Union[Tensor, List[Tensor], Tuple[Tensor, ...], Dict[Any, Tensor]]:
+    if torch.is_tensor(x):
+        x = x.to(device=device, non_blocking=non_blocking)
+    elif isinstance(x, tuple):
+        return tuple(move_to_device(item, device, non_blocking) for item in x)
+    elif isinstance(x, list):
+        return [move_to_device(item, device, non_blocking) for item in x]
+    elif isinstance(x, dict):
+        return {key: move_to_device(item, device, non_blocking) for key, item in x.items()}
     return x