Skip to content

Commit

Permalink
- Added Docker support
Browse files Browse the repository at this point in the history
- Fixes to weights.py
  • Loading branch information
m-lyon committed Apr 5, 2022
1 parent ae9496b commit 6c1d400
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 12 deletions.
5 changes: 5 additions & 0 deletions Docker_CPU
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# syntax=docker/dockerfile:latest

FROM tensorflow/tensorflow:2.7.1
RUN pip install dmri-rcnn
RUN dmri_rcnn_download_all_weights.py
5 changes: 5 additions & 0 deletions Docker_GPU
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# syntax=docker/dockerfile:latest

FROM nvcr.io/nvidia/tensorflow:22.03-tf2-py3
RUN pip install dmri-rcnn
RUN dmri_rcnn_download_all_weights.py
24 changes: 22 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@

This project enhances the angular resolution of dMRI data through the use of a Recurrent CNN.

## Table of contents
* [Installation](#installation)
* [Inference](#inference)
* [Training](#training)
* [Docker](#docker)

## Installation
`dMRI-RCNN` can be installed by via pip:
```bash
pip install dmri-rcnn
```

### Requirements
`dMRI-RCNN` uses [TensorFlow](https://www.tensorflow.org/) as the deep learning architecture.
`dMRI-RCNN` uses [TensorFlow](https://www.tensorflow.org/) as the deep learning architecture. To enable [GPU usage within TensorFlow](https://www.tensorflow.org/install/gpu), you should ensure the appropriate prerequisites are installed.

Listed below are the requirements for this package.
- `tensorflow>=2.6.0`
Expand Down Expand Up @@ -145,7 +151,21 @@ validation_data = processor.load_data(['/path/to/val_data0.tfrecord'], validatio
model.fit(train_data, epochs=10, validation_data=validation_data)
```

## Docker
You can also use `dMRI-RCNN` directly via [Docker](https://www.docker.com/). Both a CPU and GPU version of the project are available.

### CPU
To use `dMRI-RCNN` with the CPU only, use:
```bash
sudo docker run -v /absolute/path/to/my/data/directory:/data -it -t mlyon93/dmri-rcnn-cpu:latest
```

### GPU
To use `dMRI-RCNN` with the GPU, first ensure the [appropriate NVIDIA prerequisites](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker) have been installed. Then use:
```bash
sudo docker run --gpus all -v /absolute/path/to/my/data/directory:/data -it -t mlyon93/dmri-rcnn-gpu:latest
```

## Roadmap
Future Additions & Improvements:
- Plot functionality
- Docker support.
14 changes: 14 additions & 0 deletions dmri_rcnn/bin/dmri_rcnn_download_all_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/usr/bin/env python3
'''Script to Download all dMRI RCNN weights'''

from dmri_rcnn.core.weights import get_weights


if __name__ == '__main__':
for model_dim in (1, 3):
for shell in (1000, 2000, 3000, 'all'):
for q_in in (6, 10, 30):
try:
get_weights(model_dim, shell, q_in)
except AttributeError:
pass
2 changes: 1 addition & 1 deletion dmri_rcnn/bin/run_dmri_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main(args):
type=int,
choices=[1, 3],
default=3,
help='Model dimensionality, choose either 1 or 3.',
help='Model dimensionality, choose either 1 or 3. Default: 3.',
)
parser.add_argument(
'-c',
Expand Down
17 changes: 10 additions & 7 deletions dmri_rcnn/core/weights.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
'''Pretrained weights functions'''

from typing import Union
import os

from urllib.request import urlretrieve
Expand Down Expand Up @@ -95,12 +96,13 @@ def download_url(url, output_path):
urlretrieve(url, filename=output_path, reporthook=pbar.update_to)


def get_weights(model_dim: int, shell: int, q_in: int, combined: bool = False) -> str:
def get_weights(model_dim: int, shell: Union[int, str], q_in: int) -> str:
'''Gets weights given model parameters, will download if not present.
Args:
model_dim: Model dimensionality, either 1 or 3
shell: dMRI shell
shell: dMRI shell, either provide int value or "all" str
to get model weights for combined model
q_in: Number of input q-space samples
combined: Return combined model if available
Expand All @@ -109,25 +111,26 @@ def get_weights(model_dim: int, shell: int, q_in: int, combined: bool = False) -
Will raise a RuntimeError if not found.
'''
try:
if combined:
if shell == 'all':
weight_dir = os.path.join(LOCAL_DIR, f'{model_dim}D_RCNN', f'{q_in}in')
weight_urls = WEIGHT_URLS[f'{model_dim}D'][shell]['all']
else:
weight_dir = os.path.join(
LOCAL_DIR, f'{model_dim}D_RCNN', f'b{shell}_{q_in}in'
)
weight_urls = WEIGHT_URLS[f'{model_dim}D'][shell][q_in]
weight_urls = WEIGHT_URLS[f'{model_dim}D'][shell][q_in]
except KeyError:
raise AttributeError(
'Weights in given configuration not found: '
+ f'{model_dim = }, {shell = }, {q_in = }, {combined = }'
+ f'{model_dim = }, {shell = }, {q_in = }'
) from None

if not os.path.exists(weight_dir):
os.makedirs(weight_dir)

for weight in weight_urls:
if not os.path.exists(os.path.join(weight_dir, weight.fname)):
if os.path.exists(os.path.join(weight_dir, weight.fname)):
print('Model weights already present.')
else:
download_url(weight.url, os.path.join(weight_dir, weight.fname))

return os.path.join(weight_dir, 'weights')
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
'tqdm',
]

version = '0.2.1'
version = '0.2.2'
this_dir = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(this_dir, 'README.md'), encoding='utf-8') as f:
long_description = f.read()
Expand All @@ -32,7 +32,7 @@
license='MIT License',
packages=find_namespace_packages(),
install_requires=install_deps,
scripts=['dmri_rcnn/bin/run_dmri_rcnn.py'],
scripts=['dmri_rcnn/bin/run_dmri_rcnn.py', 'dmri_rcnn/bin/dmri_rcnn_download_all_weights.py'],
classifiers=[
'Programming Language :: Python',
'Operating System :: Unix',
Expand Down

0 comments on commit 6c1d400

Please sign in to comment.