|
| 1 | +#!/usr/bin/env python3 |
| 2 | +import hashlib |
| 3 | +import os |
| 4 | + |
| 5 | +from typing import Optional |
| 6 | +from urllib.request import urlopen, Request |
| 7 | +from pathlib import Path |
| 8 | +from zipfile import ZipFile |
| 9 | + |
| 10 | +REPO_BASE_DIR = Path(__file__).absolute().parent.parent |
| 11 | +DATA_DIR = REPO_BASE_DIR / "_data" |
| 12 | +BEGINNER_DATA_DIR = REPO_BASE_DIR / "beginner_source" / "data" |
| 13 | +INTERMEDIATE_DATA_DIR = REPO_BASE_DIR / "intermediate_source" / "data" |
| 14 | +ADVANCED_DATA_DIR = REPO_BASE_DIR / "advanced_source" / "data" |
| 15 | +PROTOTYPE_DATA_DIR = REPO_BASE_DIR / "prototype_source" / "data" |
| 16 | +FILES_TO_RUN = os.getenv("FILES_TO_RUN") |
| 17 | + |
| 18 | + |
| 19 | +def size_fmt(nbytes: int) -> str: |
| 20 | + """Returns a formatted file size string""" |
| 21 | + KB = 1024 |
| 22 | + MB = 1024 * KB |
| 23 | + GB = 1024 * MB |
| 24 | + if abs(nbytes) >= GB: |
| 25 | + return f"{nbytes * 1.0 / GB:.2f} Gb" |
| 26 | + elif abs(nbytes) >= MB: |
| 27 | + return f"{nbytes * 1.0 / MB:.2f} Mb" |
| 28 | + elif abs(nbytes) >= KB: |
| 29 | + return f"{nbytes * 1.0 / KB:.2f} Kb" |
| 30 | + return str(nbytes) + " bytes" |
| 31 | + |
| 32 | + |
| 33 | +def download_url_to_file(url: str, |
| 34 | + dst: Optional[str] = None, |
| 35 | + prefix: Optional[Path] = None, |
| 36 | + sha256: Optional[str] = None) -> Path: |
| 37 | + dst = dst if dst is not None else Path(url).name |
| 38 | + dst = dst if prefix is None else str(prefix / dst) |
| 39 | + if Path(dst).exists(): |
| 40 | + print(f"Skip downloading {url} as {dst} already exists") |
| 41 | + return Path(dst) |
| 42 | + file_size = None |
| 43 | + u = urlopen(Request(url, headers={"User-Agent": "tutorials.downloader"})) |
| 44 | + meta = u.info() |
| 45 | + if hasattr(meta, 'getheaders'): |
| 46 | + content_length = meta.getheaders("Content-Length") |
| 47 | + else: |
| 48 | + content_length = meta.get_all("Content-Length") |
| 49 | + if content_length is not None and len(content_length) > 0: |
| 50 | + file_size = int(content_length[0]) |
| 51 | + sha256_sum = hashlib.sha256() |
| 52 | + with open(dst, "wb") as f: |
| 53 | + while True: |
| 54 | + buffer = u.read(32768) |
| 55 | + if len(buffer) == 0: |
| 56 | + break |
| 57 | + sha256_sum.update(buffer) |
| 58 | + f.write(buffer) |
| 59 | + digest = sha256_sum.hexdigest() |
| 60 | + if sha256 is not None and sha256 != digest: |
| 61 | + Path(dst).unlink() |
| 62 | + raise RuntimeError(f"Downloaded {url} has unexpected sha256sum {digest} should be {sha256}") |
| 63 | + print(f"Downloaded {url} sha256sum={digest} size={size_fmt(file_size)}") |
| 64 | + return Path(dst) |
| 65 | + |
| 66 | + |
| 67 | +def unzip(archive: Path, tgt_dir: Path) -> None: |
| 68 | + with ZipFile(str(archive), "r") as zip_ref: |
| 69 | + zip_ref.extractall(str(tgt_dir)) |
| 70 | + |
| 71 | + |
| 72 | +def download_hymenoptera_data(): |
| 73 | + # transfer learning tutorial data |
| 74 | + z = download_url_to_file("https://download.pytorch.org/tutorial/hymenoptera_data.zip", |
| 75 | + prefix=DATA_DIR, |
| 76 | + sha256="fbc41b31d544714d18dd1230b1e2b455e1557766e13e67f9f5a7a23af7c02209", |
| 77 | + ) |
| 78 | + unzip(z, BEGINNER_DATA_DIR) |
| 79 | + |
| 80 | + |
| 81 | +def download_nlp_data() -> None: |
| 82 | + # nlp tutorial data |
| 83 | + z = download_url_to_file("https://download.pytorch.org/tutorial/data.zip", |
| 84 | + prefix=DATA_DIR, |
| 85 | + sha256="fb317e80248faeb62dc25ef3390ae24ca34b94e276bbc5141fd8862c2200bff5", |
| 86 | + ) |
| 87 | + # This will unzip all files in data.zip to intermediate_source/data/ folder |
| 88 | + unzip(z, INTERMEDIATE_DATA_DIR.parent) |
| 89 | + |
| 90 | + |
| 91 | +def download_dcgan_data() -> None: |
| 92 | + # Download dataset for beginner_source/dcgan_faces_tutorial.py |
| 93 | + z = download_url_to_file("https://s3.amazonaws.com/pytorch-tutorial-assets/img_align_celeba.zip", |
| 94 | + prefix=DATA_DIR, |
| 95 | + sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74", |
| 96 | + ) |
| 97 | + unzip(z, BEGINNER_DATA_DIR / "celeba") |
| 98 | + |
| 99 | + |
| 100 | +def download_lenet_mnist() -> None: |
| 101 | + # Download model for beginner_source/fgsm_tutorial.py |
| 102 | + download_url_to_file("https://docs.google.com/uc?export=download&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl", |
| 103 | + prefix=BEGINNER_DATA_DIR, |
| 104 | + dst="lenet_mnist_model.pth", |
| 105 | + sha256="cb5f8e578aef96d5c1a2cc5695e1aa9bbf4d0fe00d25760eeebaaac6ebc2edcb", |
| 106 | + ) |
| 107 | + |
| 108 | + |
| 109 | +def main() -> None: |
| 110 | + DATA_DIR.mkdir(exist_ok=True) |
| 111 | + BEGINNER_DATA_DIR.mkdir(exist_ok=True) |
| 112 | + ADVANCED_DATA_DIR.mkdir(exist_ok=True) |
| 113 | + INTERMEDIATE_DATA_DIR.mkdir(exist_ok=True) |
| 114 | + PROTOTYPE_DATA_DIR.mkdir(exist_ok=True) |
| 115 | + |
| 116 | + if FILES_TO_RUN is None or "transfer_learning_tutorial" in FILES_TO_RUN: |
| 117 | + download_hymenoptera_data() |
| 118 | + nlp_tutorials = ["seq2seq_translation_tutorial", "char_rnn_classification_tutorial", "char_rnn_generation_tutorial"] |
| 119 | + if FILES_TO_RUN is None or any(x in FILES_TO_RUN for x in nlp_tutorials): |
| 120 | + download_nlp_data() |
| 121 | + if FILES_TO_RUN is None or "dcgan_faces_tutorial" in FILES_TO_RUN: |
| 122 | + download_dcgan_data() |
| 123 | + if FILES_TO_RUN is None or "fgsm_tutorial" in FILES_TO_RUN: |
| 124 | + download_lenet_mnist() |
| 125 | + |
| 126 | + |
| 127 | +if __name__ == "__main__": |
| 128 | + main() |
0 commit comments