Skip to content

Commit

Permalink
move_to_device
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Feb 13, 2024
1 parent bdc37d8 commit 9b6eb79
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions pytorch_toolbelt/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import functools
import warnings
import logging
from typing import Optional, Sequence, Union, Dict, List, Any, Iterable, Callable
from typing import Optional, Sequence, Union, Dict, List, Any, Iterable, Callable, Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -40,6 +40,7 @@
"to_numpy",
"to_tensor",
"transfer_weights",
"move_to_device",
"move_to_device_non_blocking",
"describe_outputs",
"get_collate_for_dataset",
Expand Down Expand Up @@ -328,8 +329,20 @@ def resize_like(x: Tensor, target: Tensor, mode: str = "bilinear", align_corners


def move_to_device_non_blocking(x: Tensor, device: torch.device) -> Tensor:
if x.device != device:
x = x.to(device=device, non_blocking=True)
return move_to_device(x, device, non_blocking=True)


def move_to_device(
x: Union[Tensor, List[Tensor], Tuple[Tensor, ...], Dict[Any, Tensor]], device: torch.device, non_blocking=False
) -> Tensor:
if torch.is_tensor(x):
x = x.to(device=device, non_blocking=non_blocking)
elif isinstance(x, tuple):
return tuple(move_to_device(item, device, non_blocking) for item in x)
elif isinstance(x, list):
return [move_to_device(item, device, non_blocking) for item in x]
elif isinstance(x, dict):
return {key: move_to_device(item, device, non_blocking) for key, item in x.items()}
return x


Expand Down

0 comments on commit 9b6eb79

Please sign in to comment.