Skip to content

Commit

Permalink
Fix pin_memory_fn for NamedTuple (pytorch#1086)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch#1085

Per title. And, even though I can add a test, this test won't be executed as we don't have a GPU CI machine yet.

I have tested on my local machine though

Pull Request resolved: pytorch#1086

Reviewed By: NivekT

Differential Revision: D44094225

Pulled By: ejguan

fbshipit-source-id: 9c8414c31b76c93cee7e31c4e2da14076e9792bf
  • Loading branch information
ejguan authored and facebook-github-bot committed Mar 16, 2023
1 parent f2a1051 commit 1472157
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
11 changes: 10 additions & 1 deletion test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from collections import defaultdict
from functools import partial
from typing import Dict
from typing import Dict, NamedTuple

import expecttest
import torch
Expand Down Expand Up @@ -91,6 +91,11 @@ async def _async_x_mul_y(x, y):
return x * y


class NamedTensors(NamedTuple):
x: torch.Tensor
y: torch.Tensor


class TestIterDataPipe(expecttest.TestCase):
def test_in_memory_cache_holder_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(10))
Expand Down Expand Up @@ -1521,6 +1526,10 @@ def test_pin_memory(self):
dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).pin_memory()
self.assertTrue(all(v.is_pinned() for d in dp for v in d.values()))

# NamedTuple
dp = IterableWrapper([NamedTensors(torch.tensor(i), torch.tensor(i + 1)) for i in range(10)]).pin_memory()
self.assertTrue(all(v.is_pinned() for d in dp for v in d))

# Dict of List of Tensors
dp = (
IterableWrapper([{str(i): [(i - 1, i), (i, i + 1)]} for i in range(10)])
Expand Down
2 changes: 1 addition & 1 deletion torchdata/datapipes/utils/pin_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def pin_memory_fn(data, device=None):
elif isinstance(data, collections.abc.Sequence):
pinned_data = [pin_memory_fn(sample, device) for sample in data] # type: ignore[assignment]
try:
type(data)(*pinned_data)
return type(data)(*pinned_data)
except TypeError:
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
return pinned_data
Expand Down

0 comments on commit 1472157

Please sign in to comment.