Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 46393fa

Browse files
author
DEKHTIARJonathan
committed
Synthetic Dataset Workload Overhauled
1 parent 5a3666a commit 46393fa

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

tftrt/examples/benchmark_runner.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from benchmark_utils import print_dict
1919
from benchmark_utils import timed_section
2020
from benchmark_utils import timed_dataset
21+
from dataloading_utils import SyntheticDataset
2122

2223
import numpy as np
2324
import tensorflow as tf
@@ -333,12 +334,7 @@ def execute_benchmark(self):
333334
if self._args.use_synthetic_data:
334335
old_ds = dataset
335336
try:
336-
dataset = dataset.take(count=1) # loop over 1 batch
337-
dataset = dataset.cache()
338-
dataset = dataset.repeat()
339-
dataset = dataset.prefetch(
340-
buffer_size=tf.data.experimental.AUTOTUNE
341-
)
337+
dataset = SyntheticDataset(old_ds, device="/gpu:0")
342338
self._debug_print(
343339
"Model dataset has been replaced by a synthetic data "
344340
"loader to minimize data loading jitter."
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# -*- coding: utf-8 -*-
4+
5+
import tensorflow as tf
6+
7+
8+
class SyntheticDataset(object):
    """Endless iterable yielding one random batch modelled on a real batch.

    A single batch is pulled from the wrapped dataset only to learn its
    structure (list/tuple, dict, or single tensor) and the shape and dtype
    of each tensor. Random tensors with matching shapes and dtypes are then
    created once under the requested device scope, and that same synthetic
    batch is yielded forever.
    """

    def __init__(self, dataset, device):
        """Wrap `dataset` and remember the target device.

        Args:
            dataset: Iterable of batches; `iter(dataset)` is taken
                immediately. Each batch must be a list/tuple of tensors, a
                dict of tensors, or a single tensor (anything exposing
                `.shape` and `.dtype`).
            device: Device string handed to `tf.device`, e.g. "/gpu:0".
        """
        self._ds_iter = iter(dataset)
        self._device = device

    @staticmethod
    def _random_tensor(t_shape, t_dtype):
        """Return a uniformly random tensor of the given shape and dtype."""
        if t_dtype == tf.bool:
            # `tf.random.uniform` has no boolean mode: threshold floats.
            return tf.random.uniform(shape=t_shape, dtype=tf.float32) < 0.5

        if t_dtype in (tf.int32, tf.int64):
            # Integer sampling requires an explicit upper bound.
            return tf.random.uniform(shape=t_shape, dtype=t_dtype, maxval=5)

        return tf.random.uniform(shape=t_shape, dtype=t_dtype)

    def __iter__(self):
        """Yield the same synthetic batch indefinitely.

        Every call consumes one batch from the wrapped dataset iterator to
        use as the shape/dtype template.

        Yields:
            A list (for list/tuple batches), a dict (for dict batches) or a
            single tensor, filled with random values.

        Raises:
            RuntimeError: If the wrapped dataset has no batch left to use
                as a template.
        """
        try:
            ds_batch = next(self._ds_iter)
        except StopIteration as err:
            # A bare StopIteration inside a generator would be turned into
            # an opaque "generator raised StopIteration" RuntimeError.
            raise RuntimeError(
                "SyntheticDataset: the source dataset is empty or exhausted; "
                "at least one batch is required to infer shapes and dtypes."
            ) from err

        with tf.device(self._device):
            # Fixed seed so the synthetic data is the same on every run.
            tf.random.set_seed(666)

            if isinstance(ds_batch, (list, tuple)):
                data_batch = [
                    self._random_tensor(t.shape, t.dtype) for t in ds_batch
                ]
            elif isinstance(ds_batch, dict):
                data_batch = {
                    k: self._random_tensor(v.shape, v.dtype)
                    for k, v in ds_batch.items()
                }
            else:
                data_batch = self._random_tensor(
                    ds_batch.shape, ds_batch.dtype
                )

        # NOTE(review): the tensors already exist at this point, so the loop
        # is kept outside the device scope to avoid leaving that scope
        # entered in the consumer between yields - confirm against the
        # original indentation.
        while True:
            yield data_batch

0 commit comments

Comments
 (0)