From c25408ac1b1503b98d150ad32f7c3c5ee0d01aca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Wed, 4 Jan 2023 12:54:19 -0500 Subject: [PATCH] ENH: Add data fetcher Add data fetcher. Add the accompanying test. Add a script to fetch a given dataset to the provided local path. Add the accompanying test. Add the script to the project scripts in the `pyproject.toml` file. Add the badge corresponding to the Zenodo record DOI to the `README`. Explain in the `README` file how to: - Fetch the data automatically. - Cite the data. Enable the package testing and coverage report steps in the `GitHub Actions` workflow file. --- .github/workflows/test_package.yml | 24 +- README.md | 21 + pyproject.toml | 2 + scripts/fetch_data.py | 41 ++ scripts/tests/__init__.py | 0 scripts/tests/test_fetch_data.py | 47 ++ tractolearn/tractoio/dataset_fetch.py | 539 ++++++++++++++++++ tractolearn/tractoio/tests/__init__.py | 0 .../tractoio/tests/test_dataset_fetch.py | 24 + 9 files changed, 686 insertions(+), 12 deletions(-) create mode 100755 scripts/fetch_data.py create mode 100644 scripts/tests/__init__.py create mode 100644 scripts/tests/test_fetch_data.py create mode 100644 tractolearn/tractoio/dataset_fetch.py create mode 100644 tractolearn/tractoio/tests/__init__.py create mode 100644 tractolearn/tractoio/tests/test_dataset_fetch.py diff --git a/.github/workflows/test_package.yml b/.github/workflows/test_package.yml index 9e2dd23..aef1ea0 100644 --- a/.github/workflows/test_package.yml +++ b/.github/workflows/test_package.yml @@ -71,20 +71,20 @@ jobs: run: | # tox --sitepackages python -c 'import tractolearn' - # coverage run --source tractolearn -m pytest tractolearn -o junit_family=xunit2 -v --doctest-modules --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml + coverage run --source tractolearn -m pytest tractolearn -o junit_family=xunit2 -v --doctest-modules --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml - #- name: Upload pytest test results - # uses: actions/upload-artifact@master - # with: - # name: pytest-results-${{ runner.os }}-${{ matrix.python-version }} - # path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml - # # Use always() to always run this step to publish test results when there are test failures - # if: always() + - name: Upload pytest test results + uses: actions/upload-artifact@master + with: + name: pytest-results-${{ runner.os }}-${{ matrix.python-version }} + path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml + # Use always() to always run this step to publish test results when there are test failures + if: always() - #- name: Statistics - # if: success() - # run: | - # coverage report + - name: Statistics + if: success() + run: | + coverage report - name: Package Setup # - name: Run tests with tox diff --git a/README.md b/README.md index 5cc1957..72b0955 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,8 @@ [![test, package](https://github.com/scil-vital/tractolearn/actions/workflows/test_package.yml/badge.svg?branch=main)](https://github.com/scil-vital/tractolearn/actions/workflows/test_package.yml?query=branch%3Amain) [![documentation](https://readthedocs.org/projects/tractolearn/badge/?version=latest)](https://tractolearn.readthedocs.io/en/latest/?badge=latest) +[![DOI tractolearn](https://zenodo.org/badge/DOI/10.5281/zenodo.7562790.svg)](https://doi.org/10.5281/zenodo.7562790) +[![DOI 
RBX](https://zenodo.org/badge/DOI/10.5281/zenodo.7562635.svg)](https://doi.org/10.5281/zenodo.7562635)
 
 Tractography learning.
 
@@ -53,6 +55,22 @@ training pipeline with the following command:
 ```
 ae_train.py train_config.yaml -vv
 ```
+## Data
+
+To automatically fetch or use the [tractolearn data](https://zenodo.org/record/7562790)
+provided, you can use the `retrieve_dataset` function located in the
+`tractolearn.tractoio.dataset_fetch` module, or the `fetch_data` script,
+e.g.:
+```shell
+fetch_data contrastive_autoencoder_weights {my_path}
+```
+
+The datasets that can be automatically fetched and used are available in
+`tractolearn.tractoio.dataset_fetch.Dataset`.
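+
+For instance, the following Python call is roughly equivalent to the above
+command (`my_path` is a placeholder for any writable destination directory):
+```python
+from tractolearn.tractoio.dataset_fetch import Dataset, retrieve_dataset
+
+weights = retrieve_dataset(Dataset.CONTRASTIVE_AUTOENCODER_WEIGHTS.name, "my_path")
+```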
+
+The [RecoBundlesX data](https://zenodo.org/record/7562635) can also be
+fetched in the same way.
+
 ## How to cite
 
 If you use this toolkit in a scientific publication or if you want to cite
@@ -77,6 +95,9 @@ our previous works, we would appreciate if you considered
 the following aspects:
 
 The corresponding `BibTeX` files are contained in the above links.
+If you use the [data](https://zenodo.org/record/7562790) made available by the
+authors, please cite the appropriate Zenodo record.
+
 Please reach out to us if you have related questions.
 
 ## Patent
diff --git a/pyproject.toml b/pyproject.toml
index 2b3573f..e279ffd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,6 +59,7 @@ test = [
     "pytest-cov",
     "pytest-pep8",
     "pytest-xdist",
+    "pytest_console_scripts",
 ]
 dev = [
     "black == 22.12",
@@ -73,6 +74,7 @@ ae_bundle_streamlines = "scripts:ae_bundle_streamlines.main"
 ae_find_thresholds = "scripts:ae_find_thresholds.main"
 ae_generate_streamlines = "scripts:ae_generate_streamlines.main"
 ae_train = "scripts:ae_train.main"
+fetch_data = "scripts:fetch_data.main"
 
 [options.extras_require]
 all = [
diff --git a/scripts/fetch_data.py b/scripts/fetch_data.py
new file mode 100755
index 0000000..36177f9
--- /dev/null
+++ b/scripts/fetch_data.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import argparse
+from pathlib import Path
+
+from tractolearn.tractoio.dataset_fetch import Dataset, retrieve_dataset
+
+
+def _build_arg_parser():
+
+    parser = argparse.ArgumentParser(
+        description="Fetch tractolearn dataset",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "dataset_name",
+        type=Dataset.argparse,
+        choices=list(Dataset),
+        help="Dataset name",
+    )
+    parser.add_argument(
+        "out_path",
+        type=Path,
+        help="Output path",
+    )
+
+    return parser
+
+
+def main():
+
+    # Parse arguments
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+
+    _ = retrieve_dataset(Dataset(args.dataset_name).name, args.out_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/tests/__init__.py b/scripts/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/tests/test_fetch_data.py b/scripts/tests/test_fetch_data.py
new file mode 100644
index 0000000..862bff7
--- /dev/null
+++ b/scripts/tests/test_fetch_data.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+from os import listdir
+from os.path import isfile, join
+
+
+def test_help_option(script_runner):
+
+    ret = script_runner.run(
+        "fetch_data", "--help"
+    )
+    assert ret.success
+
+
+def test_execution(script_runner):
+
+    # Test the lightest datasets
+    with tempfile.TemporaryDirectory() as tmp_dir:
+
+        os.chdir(os.path.expanduser(tmp_dir))
+
+        ret = script_runner.run(
+            "fetch_data",
+            "contrastive_autoencoder_weights",
+            tmp_dir)
+
+        assert ret.success
+
+        files = [f for f in listdir(tmp_dir) if isfile(join(tmp_dir, f))]
+        assert len(files) == 1
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+
+        os.chdir(os.path.expanduser(tmp_dir))
+
+        ret = script_runner.run(
+            "fetch_data",
+            "mni2009cnonlinsymm_anat",
+            tmp_dir)
+
+        assert ret.success
+
+        files = [f for f in listdir(tmp_dir) if isfile(join(tmp_dir, f))]
+        assert len(files) == 1
diff --git a/tractolearn/tractoio/dataset_fetch.py b/tractolearn/tractoio/dataset_fetch.py
new file mode 100644
index 0000000..d8120c9
--- /dev/null
+++ b/tractolearn/tractoio/dataset_fetch.py
@@ -0,0 +1,539 @@
+# -*- coding: utf-8 -*-
+
+import contextlib
+import enum
+import logging
+import os
+import tarfile
+import zipfile
+from hashlib import md5
+from os.path import join as pjoin
+from shutil import copyfileobj
+from urllib.request import urlopen
+
+from tqdm.auto import tqdm
+
+logger = logging.getLogger(__name__)
+
+TRACTOLEARN_DATASETS_URL = "https://zenodo.org/record/"
+
+key_separator = ","
+
+
+class Dataset(enum.Enum):
+    """Datasets for tractography learning.
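+
+    Members are `argparse`-friendly; for illustration, both of the following
+    resolve to the same member, and ``str()`` returns the lowercase member
+    name used on the command line::
+
+        Dataset.RECOBUNDLESX_CONFIG
+        Dataset.argparse("recobundlesx_config")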
+ """ + + msg = f"Unknown dataset.\nProvided: {name}; Available: {Dataset.__members__.keys()}" + return msg + + +def _get_file_hash(filename): + """Generate an MD5 hash for the entire file in blocks of 128. + + Parameters + ---------- + filename : str + The path to the file whose MD5 hash is to be generated. + + Returns + ------- + hash256_data : str + The computed MD5 hash from the input file. + """ + + hash_data = md5() + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(128 * hash_data.block_size), b""): + hash_data.update(chunk) + return hash_data.hexdigest() + + +def check_hash(filename, stored_hash=None): + """Check that the hash of the given filename equals the stored one. + + Parameters + ---------- + filename : str + The path to the file whose hash is to be compared. + stored_hash : str, optional + Used to verify the generated hash. + Default: None, checking is skipped. + """ + + if stored_hash is not None: + computed_hash = _get_file_hash(filename) + if stored_hash.lower() != computed_hash: + msg = ( + f"The downloaded file\n{filename}\ndoes not have the expected hash " + f"value of {stored_hash}.\nInstead, the hash value was {computed_hash}.\nThis could " + "mean that something is wrong with the file or that the " + "upstream file has been updated.\nYou can try downloading " + f"file again or updating to the newest version of {__name__.split('.')[0]}" + ) + raise FetcherError(msg) + + +def _get_file_data(fname, url): + + with contextlib.closing(urlopen(url)) as opener: + try: + response_size = opener.headers["content-length"] + except KeyError: + response_size = None + + with open(fname, "wb") as data: + if response_size is None: + copyfileobj(opener, data) + else: + copyfileobj_withprogress(opener, data, response_size) + + +def fetch_data(files, folder, data_size=None): + """Download files to folder and checks their hashes. + + Parameters + ---------- + files : dictionary + For each file in ``files`` the value should be (url, hash). The file + will be downloaded from url if the file does not already exist or if + the file exists but the hash does not match. + folder : str + The directory where to save the file, the directory will be created if + it does not already exist. + data_size : str, optional + A string describing the size of the data (e.g. "91 MB") to be logged to + the screen. Default does not produce any information about data size. + + Raises + ------ + FetcherError + Raises if the hash of the file does not match the expected value. The + downloaded file is not deleted when this error is raised. + """ + + if not os.path.exists(folder): + logger.info(f"Creating new folder\n{folder}") + os.makedirs(folder) + + if data_size is not None: + logger.info(f"Data size is approximately {data_size}") + + all_skip = True + for f in files: + url, _hash = files[f] + fullpath = pjoin(folder, f) + if os.path.exists(fullpath) and ( + _get_file_hash(fullpath) == _hash.lower() + ): + continue + all_skip = False + logger.info(f"Downloading\n{f}\nto\n{folder}") + _get_file_data(fullpath, url) + check_hash(fullpath, _hash) + if all_skip: + _already_there_msg(folder) + else: + logger.info(f"\nFiles successfully downloaded to\n{folder}") + + +def _make_fetcher( + folder, + name, + baseurl, + remote_fnames, + local_fnames, + hash_list=None, + doc="", + data_size=None, + msg=None, + unzip=False, +): + """Create a new fetcher. + + Parameters + ---------- + folder : str + The full path to the folder in which the files would be placed locally. 
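+
+    Examples
+    --------
+    A minimal sketch; the record URL and checksum are illustrative
+    placeholders::
+
+        files = {
+            "degree.json": (
+                TRACTOLEARN_DATASETS_URL + "7562790/files/degree.json",
+                "9b97737ac0f9f362f9936792028de934",
+            )
+        }
+        fetch_data(files, "/tmp/tractolearn_data", data_size="670B")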
+    """
+
+    if not os.path.exists(folder):
+        logger.info(f"Creating new folder\n{folder}")
+        os.makedirs(folder)
+
+    if data_size is not None:
+        logger.info(f"Data size is approximately {data_size}")
+
+    all_skip = True
+    for f in files:
+        url, _hash = files[f]
+        fullpath = pjoin(folder, f)
+        if os.path.exists(fullpath) and (
+            _get_file_hash(fullpath) == _hash.lower()
+        ):
+            continue
+        all_skip = False
+        logger.info(f"Downloading\n{f}\nto\n{folder}")
+        _get_file_data(fullpath, url)
+        check_hash(fullpath, _hash)
+    if all_skip:
+        _already_there_msg(folder)
+    else:
+        logger.info(f"\nFiles successfully downloaded to\n{folder}")
+
+
+def _make_fetcher(
+    folder,
+    name,
+    baseurl,
+    remote_fnames,
+    local_fnames,
+    hash_list=None,
+    doc="",
+    data_size=None,
+    msg=None,
+    unzip=False,
+):
+    """Create a new fetcher.
+
+    Parameters
+    ----------
+    folder : str
+        The full path to the folder in which the files would be placed locally.
+    name : str
+        The name of the fetcher function.
+    baseurl : str
+        The URL from which this fetcher reads files.
+    remote_fnames : list of strings
+        The names of the files in the baseurl location.
+    local_fnames : list of strings
+        The names of the files to be saved on the local filesystem.
+    hash_list : list of strings, optional
+        The hash values of the files. Used to verify the content of the files.
+        Default: None, skipping the hash check.
+    doc : str, optional
+        Documentation of the fetcher.
+    data_size : str, optional
+        If provided, is sent as a message to the user before downloading
+        starts.
+    msg : str, optional
+        A message to print to screen when fetching takes place. Default (None)
+        is to print nothing.
+    unzip : bool, optional
+        Whether to unzip the file(s) after downloading them. Supports zip and
+        tarball (`.tar.gz`/`.tar.bz2`) files.
+
+    Returns
+    -------
+    fetcher : function
+        A function that, when called, fetches data according to the designated
+        inputs.
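+
+    Examples
+    --------
+    Sketch of a single-file fetcher; the folder and checksum are
+    illustrative::
+
+        fetch_example_config = _make_fetcher(
+            "/tmp/tractolearn_data",
+            "fetch_example_config",
+            TRACTOLEARN_DATASETS_URL + "7562790/files/",
+            ["degree.json"],
+            ["degree.json"],
+            hash_list=["9b97737ac0f9f362f9936792028de934"],
+        )
+        files, folder = fetch_example_config()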
+    """
+
+    def fetcher():
+        files = {}
+        for (
+            i,
+            (f, n),
+        ) in enumerate(zip(remote_fnames, local_fnames)):
+            files[n] = (
+                baseurl + f,
+                hash_list[i] if hash_list is not None else None,
+            )
+        fetch_data(files, folder, data_size)
+
+        if msg is not None:
+            logger.info(msg)
+        if unzip:
+            for f in local_fnames:
+                split_ext = os.path.splitext(f)
+                if split_ext[-1] == ".gz" or split_ext[-1] == ".bz2":
+                    if os.path.splitext(split_ext[0])[-1] == ".tar":
+                        ar = tarfile.open(pjoin(folder, f))
+                        ar.extractall(path=folder)
+                        ar.close()
+                    else:
+                        raise ValueError("File extension is not recognized")
+                elif split_ext[-1] == ".zip":
+                    z = zipfile.ZipFile(pjoin(folder, f), "r")
+                    files[f] += (tuple(z.namelist()),)
+                    z.extractall(folder)
+                    z.close()
+                else:
+                    raise ValueError("File extension is not recognized")
+
+        return files, folder
+
+    fetcher.__name__ = name
+    fetcher.__doc__ = doc
+    return fetcher
+
+
+fetch_bundle_label_config = (
+    "fetch_bundle_label_config",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["rbx_atlas_v10.json"],
+    ["rbx_atlas_v10.json"],
+    ["0edde5be1b3e32d12a5f02d77b46d32b"],
+    "Bundle labels",
+    "581B",
+    "",
+    False,
+)
+
+fetch_contrastive_ae_weights = (
+    "fetch_contrastive_ae_weights",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["best_model_contrastive_tractoinferno_hcp.pt"],
+    ["best_model_contrastive_tractoinferno_hcp.pt"],
+    ["2181aa950d8110b89f5b4bf7ebbb9aff"],
+    "Download contrastive-loss trained tractolearn autoencoder weights",
+    "56.7MB",
+    "",
+    False,
+)
+
+fetch_mni2009cnonlinsymm_anat = (
+    "fetch_mni2009cnonlinsymm_anat",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["mni_masked.nii.gz"],
+    ["mni_masked.nii.gz"],
+    ["ea6c119442d23a25033de19b55c607d3"],
+    "Download MNI ICBM 2009c Nonlinear Symmetric 1x1x1 mm template dataset",
+    "4.9MB",
+    "",
+    False,
+)
+
+fetch_generative_loa_cone_config = (
+    "fetch_generative_loa_cone_config",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["degree.json"],
+    ["degree.json"],
+    ["9b97737ac0f9f362f9936792028de934"],
+    "Bundle-wise local orientation angle cone in degrees for generative streamline assessment",
+    "670B",
+    "",
+    False,
+)
+
+fetch_generative_seed_streamline_ratio_config = (
+    "fetch_generative_seed_streamline_ratio_config",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["ratio.json"],
+    ["ratio.json"],
+    ["69d14a2100d9a948f63489767586a939"],
+    "Bundle-wise (subject | atlas) seed streamline ratio",
+    "978B",
+    "",
+    False,
+)
+
+fetch_generative_streamline_max_count_config = (
+    "fetch_generative_streamline_max_count_config",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["max_total_sampling.json"],
+    ["max_total_sampling.json"],
+    ["33f46357fd791d2c6f6b7da473fd8bbc"],
+    "Maximum number of generative bundle-wise streamline count",
+    "787B",
+    "",
+    False,
+)
+
+fetch_generative_streamline_req_count_config = (
+    "fetch_generative_streamline_req_count_config",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["number_rejection_sampling.json"],
+    ["number_rejection_sampling.json"],
+    ["46acbb64aa3b2ba846c727cb8554566d"],
+    "Requested number of generative bundle-wise streamline count",
+    "783B",
+    "",
+    False,
+)
+
+fetch_generative_wm_tissue_criterion_config = (
+    "fetch_generative_wm_tissue_criterion_config",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["white_matter_mask.json"],
+    ["white_matter_mask.json"],
+    ["79b7605ec13e0b2bcb6a38c15f159381"],
+    "Bundle-wise WM tissue criterion (WM mask | thresholded FA) for generative streamline assessment",
+    "748B",
+    "",
+    False,
+)
+
+fetch_recobundlesx_atlas = (
+    "fetch_recobundlesx_atlas",
+    TRACTOLEARN_DATASETS_URL + "7562635/files/",
+    ["atlas.zip"],
+    ["atlas.zip"],
+    ["0d2857efa7cfda6f57e5abcad4717c2a"],
+    "Download RecoBundlesX population average and centroid tractograms",
+    "159.0MB",
+    "",
+    True,
+)
+
+fetch_recobundlesx_config = (
+    "fetch_recobundlesx_config",
+    TRACTOLEARN_DATASETS_URL + "7562635/files/",
+    ["config.zip"],
+    ["config.zip"],
+    ["439e2488597243455872ec3dcb50eda7"],
+    "Download RecoBundlesX clustering parameter values",
+    "3.6KB",
+    "",
+    True,
+)
+
+fetch_tractoinferno_hcp_contrastive_threshold_config = (
+    "fetch_tractoinferno_hcp_contrastive_threshold_config",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["thresholds_contrastive_tractoinferno_hcp.json"],
+    ["thresholds_contrastive_tractoinferno_hcp.json"],
+    ["beecb87d73d53fa4f4ed6714af420cfd"],
+    "Bundle-wise bundling latent space distance threshold values",
+    "824B",
+    "",
+    False,
+)
+
+fetch_tractoinferno_hcp_ref_tractography = (
+    "fetch_tractoinferno_hcp_ref_tractography",
+    TRACTOLEARN_DATASETS_URL + "7562790/files/",
+    ["data_tractoinferno_hcp_qbx.hdf5"],
+    ["data_tractoinferno_hcp_qbx.hdf5"],
+    ["4803d36278d1575a40e9048a7380aa10"],
+    "Download TractoInferno-HCP reference tractography dataset",
+    "74.0GB",
+    "",
+    False,
+)
+
+
+def retrieve_dataset(name, path):
+    """Retrieve the given dataset to the provided path.
+
+    Parameters
+    ----------
+    name : string
+        Dataset name.
+    path : string
+        Destination path.
+
+    Returns
+    -------
+    fnames : string or list
+        Filenames for dataset.
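+
+    Examples
+    --------
+    Fetching a single-file dataset; the destination path is illustrative::
+
+        fname = retrieve_dataset(
+            Dataset.MNI2009CNONLINSYMM_ANAT.name, "/tmp/tractolearn_data"
+        )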
+ """ + + logger.info(f"\nDataset: {name}") + + if name == Dataset.BUNDLE_LABEL_CONFIG.name: + params = fetch_bundle_label_config + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + elif name == Dataset.CONTRASTIVE_AUTOENCODER_WEIGHTS.name: + params = fetch_contrastive_ae_weights + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + elif name == Dataset.MNI2009CNONLINSYMM_ANAT.name: + params = fetch_mni2009cnonlinsymm_anat + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + elif name == Dataset.GENERATIVE_LOA_CONE_CONFIG.name: + params = fetch_generative_loa_cone_config + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + elif name == Dataset.GENERATIVE_SEED_STRML_RATIO_CONFIG.name: + params = fetch_generative_seed_streamline_ratio_config + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + elif name == Dataset.GENERATIVE_STRML_MAX_COUNT_CONFIG.name: + params = fetch_generative_streamline_max_count_config + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + elif name == Dataset.GENERATIVE_STRML_RQ_COUNT_CONFIG.name: + params = fetch_generative_streamline_max_count_config + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + elif name == Dataset.GENERATIVE_WM_TISSUE_CRITERION_CONFIG.name: + params = fetch_generative_wm_tisue_criterion_config + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + elif name == Dataset.RECOBUNDLESX_ATLAS.name: + params = fetch_recobundlesx_atlas + files, folder = _make_fetcher(path, *params)() + fnames = files["atlas.zip"][2] + return sorted([pjoin(folder, f) for f in fnames if os.path.isfile(pjoin(folder, f))]) + elif name == Dataset.RECOBUNDLESX_CONFIG.name: + params = fetch_recobundlesx_config + files, folder = _make_fetcher(path, *params)() + fnames = files["config.zip"][2] + return sorted([pjoin(folder, f) for f in fnames]) + elif name == Dataset.TRACTOINFERNO_HCP_CONTRASTIVE_THR_CONFIG.name: + params = fetch_tractoinferno_hcp_contrastive_threshold_config + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + elif name == Dataset.TRACTOINFERNO_HCP_REF_TRACTOGRAPHY.name: + params = fetch_tractoinferno_hcp_ref_tractography + files, folder = _make_fetcher(path, *params)() + return pjoin(folder, list(files.keys())[0]) + else: + raise DatasetError(_unknown_dataset_msg(name)) diff --git a/tractolearn/tractoio/tests/__init__.py b/tractolearn/tractoio/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tractolearn/tractoio/tests/test_dataset_fetch.py b/tractolearn/tractoio/tests/test_dataset_fetch.py new file mode 100644 index 0000000..526a0e5 --- /dev/null +++ b/tractolearn/tractoio/tests/test_dataset_fetch.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os + +from tractolearn.tractoio.dataset_fetch import Dataset, retrieve_dataset + + +def test_retrieve_dataset(tmp_path): + + dataset_keys = list(Dataset.__members__.keys()) + # Exclude attempting to download the HDF% file due to its size + dataset_keys.remove(Dataset.TRACTOINFERNO_HCP_REF_TRACTOGRAPHY.name) + + for name in dataset_keys: + files = retrieve_dataset(name, tmp_path) + + if isinstance(files, str): + assert os.path.isfile(files) + elif isinstance(files, list): + 
+                assert os.path.isfile(elem)
+        else:
+            raise TypeError("Unexpected type found.")