Skip to content

Commit 2f6afc1

Browse files
committed
Report dataset download progress
1 parent 73c3b6c commit 2f6afc1

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

benchmark/dataset.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import tarfile
44
import urllib.request
55
from dataclasses import dataclass, field
6-
from typing import Dict, Optional
6+
from typing import Callable, Dict, Optional
77
from urllib.request import build_opener, install_opener
88

9+
import tqdm
10+
911
from benchmark import DATASETS_DIR
1012
from dataset_reader.ann_compound_reader import AnnCompoundReader
1113
from dataset_reader.ann_h5_reader import AnnH5Reader
@@ -54,7 +56,12 @@ def download(self):
5456

5557
if self.config.link:
5658
print(f"Downloading {self.config.link}...")
57-
tmp_path, _ = urllib.request.urlretrieve(self.config.link)
59+
with tqdm.tqdm(
60+
unit="B", unit_scale=True, miniters=1, dynamic_ncols=True, disable=None
61+
) as t:
62+
tmp_path, _ = urllib.request.urlretrieve(
63+
self.config.link, reporthook=_tqdm_reporthook(t)
64+
)
5865

5966
if self.config.link.endswith(".tgz") or self.config.link.endswith(
6067
".tar.gz"
@@ -76,6 +83,15 @@ def get_reader(self, normalize: bool) -> BaseReader:
7683
return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
7784

7885

86+
def _tqdm_reporthook(t: tqdm.tqdm) -> Callable[[int, int, int], None]:
87+
def reporthook(blocknum: int, block_size: int, total_size: int) -> None:
88+
if total_size > 0:
89+
t.total = total_size
90+
t.update(blocknum * block_size - t.n)
91+
92+
return reporthook
93+
94+
7995
if __name__ == "__main__":
8096
dataset = Dataset(
8197
{

0 commit comments

Comments
 (0)