Skip to content

Commit

Permalink
Merge pull request #26 from jhlegarreta/AddOSFDataFetcher
Browse files Browse the repository at this point in the history
ENH: Add data fetcher
  • Loading branch information
jhlegarreta authored Feb 6, 2023
2 parents 08e047a + c25408a commit 6a9d2c2
Show file tree
Hide file tree
Showing 9 changed files with 686 additions and 12 deletions.
24 changes: 12 additions & 12 deletions .github/workflows/test_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,20 @@ jobs:
run: |
# tox --sitepackages
python -c 'import tractolearn'
# coverage run --source tractolearn -m pytest tractolearn -o junit_family=xunit2 -v --doctest-modules --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml
coverage run --source tractolearn -m pytest tractolearn -o junit_family=xunit2 -v --doctest-modules --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml
#- name: Upload pytest test results
# uses: actions/upload-artifact@master
# with:
# name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}
# path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml
# # Use always() to always run this step to publish test results when there are test failures
# if: always()
- name: Upload pytest test results
uses: actions/upload-artifact@master
with:
name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}
path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: always()

#- name: Statistics
# if: success()
# run: |
# coverage report
- name: Statistics
if: success()
run: |
coverage report
- name: Package Setup
# - name: Run tests with tox
Expand Down
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

[![test, package](https://github.com/scil-vital/tractolearn/actions/workflows/test_package.yml/badge.svg?branch=main)](https://github.com/scil-vital/tractolearn/actions/workflows/test_package.yml?query=branch%3Amain)
[![documentation](https://readthedocs.org/projects/tractolearn/badge/?version=latest)](https://tractolearn.readthedocs.io/en/latest/?badge=latest)
[![DOI tractolearn](https://zenodo.org/badge/DOI/10.5281/zenodo.7562790.svg)](https://doi.org/10.5281/zenodo.7562790)
[![DOI RBX](https://zenodo.org/badge/DOI/10.5281/zenodo.7562635.svg)](https://doi.org/10.5281/zenodo.7562635)

Tractography learning.

Expand Down Expand Up @@ -53,6 +55,22 @@ training pipeline with the following command:
ae_train.py train_config.yaml -vv
```

## Data

To automatically fetch or use the [tractolearn data](https://zenodo.org/record/7562790)
provided, you can use the `retrieve_dataset` method located in the
`tractolearn.tractoio.dataset_fetch` module, or the `dataset_fetch` script,
e.g.:
```shell
fetch_data contrastive_autoencoder_weights {my_path}
```

The datasets that can be automatically fetched and used are available in
`tractolearn.tractoio.dataset_fetch.Dataset`.

Fetching the [RecoBundlesX data](https://zenodo.org/record/7562635) is also
made available.

## How to cite

If you use this toolkit in a scientific publication or if you want to cite
Expand All @@ -77,6 +95,9 @@ our previous works, we would appreciate if you considered the following aspects:

The corresponding `BibTeX` files are contained in the above links.

If you use the [data](https://zenodo.org/record/7562790) made available by the
authors, please cite the appropriate Zenodo record.

Please reach out to us if you have related questions.

## Patent
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ test = [
"pytest-cov",
"pytest-pep8",
"pytest-xdist",
"pytest_console_scripts",
]
dev = [
"black == 22.12",
Expand All @@ -73,6 +74,7 @@ ae_bundle_streamlines = "scripts:ae_bundle_streamlines.main"
ae_find_thresholds = "scripts:ae_find_thresholds.main"
ae_generate_streamlines = "scripts:ae_generate_streamlines.main"
ae_train = "scripts:ae_train.main"
fetch_data = "scripts:fetch_data.main"

[options.extras_require]
all = [
Expand Down
41 changes: 41 additions & 0 deletions scripts/fetch_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
from pathlib import Path

from tractolearn.tractoio.dataset_fetch import Dataset, retrieve_dataset


def _build_arg_parser():

parser = argparse.ArgumentParser(
description="Fetch tractolearn dataset",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"dastaset_name",
type=Dataset.argparse,
choices=list(Dataset),
help="Dataset name",
)
parser.add_argument(
"out_path",
type=Path,
help="Output path",
)

return parser


def main():

# Parse arguments
parser = _build_arg_parser()
args = parser.parse_args()

_ = retrieve_dataset(Dataset(args.dastaset_name).name, args.out_path)


if __name__ == "__main__":
main()
Empty file added scripts/tests/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions scripts/tests/test_fetch_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import tempfile
from os import listdir
from os.path import isfile, join


def test_help_option(script_runner):

ret = script_runner.run(
"fetch_data.py", "--help"
)
assert ret.success


def test_execution(script_runner):

# Test the lightest datasets
with tempfile.TemporaryDirectory() as tmp_dir:

os.chdir(os.path.expanduser(tmp_dir))

ret = script_runner.run(
"fetch_data.py",
"contrastive_autoencoder_weights",
tmp_dir)

assert ret.success

files = [f for f in listdir(tmp_dir) if isfile(join(tmp_dir, f))]
assert len(files) == 1

with tempfile.TemporaryDirectory() as tmp_dir:

os.chdir(os.path.expanduser(tmp_dir))

ret = script_runner.run(
"fetch_data.py",
"mni2009cnonlinsymm_anat",
tmp_dir)

assert ret.success

files = [f for f in listdir(tmp_dir) if isfile(join(tmp_dir, f))]
assert len(files) == 1
Loading

0 comments on commit 6a9d2c2

Please sign in to comment.