Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit a95239f

Browse files
author
DEKHTIARJonathan
committed
[Benchmarking-Py] Fix Synthetic Dataset with DS returning x, y
1 parent 7c0f462 commit a95239f

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

tftrt/benchmarking-python/dataloading_utils.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@
1414
def SyntheticDataset(dataset, device):
1515
data_batch = next(iter(dataset))
1616

17-
def copy_on_device(t):
18-
19-
if t.dtype != tf.int32:
20-
with tf.device(device):
21-
return tf.identity(t)
22-
23-
return t
24-
25-
if isinstance(data_batch, (tuple, list)):
26-
data_batch = [copy_on_device(t) for t in data_batch]
27-
elif isinstance(data_batch, dict):
28-
data_batch = {k: copy_on_device(t) for k, t in data_batch.items()}
29-
else:
30-
data_batch = copy_on_device(data_batch)
17+
def copy_on_device(data):
18+
if isinstance(data, (tuple, list)):
19+
return [copy_on_device(t) for t in data]
20+
elif isinstance(data, dict):
21+
return {k: copy_on_device(t) for k, t in data.items()}
22+
else:
23+
try:
24+
if data.dtype != tf.int32:
25+
with tf.device(device):
26+
return tf.identity(data)
27+
except AttributeError:
28+
pass
29+
30+
return data
31+
32+
data_batch = copy_on_device(data_batch)
3133

3234
return itertools.repeat(data_batch)
3335

0 commit comments

Comments
 (0)