Skip to content

Commit

Permalink
Merge pull request #217 from EmmaRenauld/follow_official_scilpy
Browse files Browse the repository at this point in the history
Follow official scilpy
  • Loading branch information
EmmaRenauld authored Nov 16, 2023
2 parents 1bb5498 + 76b059e commit f470467
Show file tree
Hide file tree
Showing 18 changed files with 15 additions and 48 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ jobs:
run: |
pip install --upgrade pip
pip install pytest
pip install -r requirements_github.txt
pip install -e .
- name: Tests
Expand Down
1 change: 0 additions & 1 deletion dwi_ml/data/processing/space/neighborhood.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import itertools
from typing import Union, List

import numpy as np
import torch
Expand Down
3 changes: 2 additions & 1 deletion dwi_ml/data/processing/streamlines/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from dipy.io.stateful_tractogram import StatefulTractogram
from nibabel.streamlines.tractogram import (PerArrayDict, PerArraySequenceDict)
import numpy as np
from scilpy.tractograms.streamline_operations import resample_streamlines_step_size

from scilpy.tracking.tools import resample_streamlines_step_size
from scilpy.utils.streamlines import compress_sft


Expand Down
3 changes: 1 addition & 2 deletions dwi_ml/data/processing/streamlines/post_processing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-
import logging
from typing import List

import numpy as np
import torch

from scilpy.tractograms.uncompress import uncompress
from scilpy.tractanalysis.tools import \
extract_longest_segments_from_profile as segmenting_func
from scilpy.tractanalysis.uncompress import uncompress

# We could try using nan instead of zeros for non-existing previous dirs...
DEFAULT_UNEXISTING_VAL = torch.zeros((1, 3), dtype=torch.float32)
Expand Down
2 changes: 1 addition & 1 deletion dwi_ml/models/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from typing import Tuple, Union, List
from typing import Tuple, List

import numpy as np
import torch
Expand Down
4 changes: 2 additions & 2 deletions dwi_ml/models/projects/learn2track_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
compute_directions, normalize_directions, compute_n_previous_dirs
from dwi_ml.data.processing.streamlines.sos_eos_management import \
convert_dirs_to_class
from dwi_ml.models.embeddings import NoEmbedding, keys_to_embeddings
from dwi_ml.models.embeddings import NoEmbedding
from dwi_ml.models.main_models import (
ModelWithPreviousDirections, ModelWithDirectionGetter,
ModelWithNeighborhood, MainModelOneInput, ModelOneInputWithEmbedding)
ModelWithNeighborhood, ModelOneInputWithEmbedding)
from dwi_ml.models.stacked_rnn import StackedRNN

logger = logging.getLogger('model_logger') # Same logger as Super.
Expand Down
1 change: 0 additions & 1 deletion dwi_ml/models/projects/learn2track_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
import argparse

from dwi_ml.models.embeddings import keys_to_embeddings
from dwi_ml.models.projects.learn2track_model import Learn2TrackModel


Expand Down
2 changes: 0 additions & 2 deletions dwi_ml/models/projects/transformers_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# -*- coding: utf-8 -*-
from dwi_ml.models.embeddings import keys_to_embeddings
from dwi_ml.models.positional_encoding import (
keys_to_positional_encodings)
from dwi_ml.models.projects.transformer_models import (
AbstractTransformerModel)
from dwi_ml.models.utils.direction_getters import check_args_direction_getter

sphere_choices = ['symmetric362', 'symmetric642', 'symmetric724',
'repulsion724', 'repulsion100', 'repulsion200']
Expand Down
2 changes: 1 addition & 1 deletion dwi_ml/testing/projects/transformer_visualisation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import add_reference_arg, add_overwrite_arg, add_bbox_arg
from scilpy.tractograms.streamline_operations import resample_streamlines_step_size
from scilpy.tracking.tools import resample_streamlines_step_size
from scilpy.utils.streamlines import compress_sft

from dwi_ml.io_utils import add_logging_arg
Expand Down
2 changes: 0 additions & 2 deletions dwi_ml/training/batch_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
from typing import List, Tuple, Iterator, Union

import numpy as np
import torch
import torch.multiprocessing
from torch.utils.data import Sampler

from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset
Expand Down
2 changes: 1 addition & 1 deletion dwi_ml/training/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
import shutil
from typing import Union, List, Tuple
from typing import Union, List

from comet_ml import (Experiment as CometExperiment, ExistingExperiment)
import numpy as np
Expand Down
2 changes: 0 additions & 2 deletions dwi_ml/training/utils/batch_samplers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
import argparse
import logging

from dwi_ml.experiment_utils.prints import format_dict_to_str
from dwi_ml.experiment_utils.timer import Timer
from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler

Expand Down
4 changes: 1 addition & 3 deletions dwi_ml/training/with_generation/batch_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# -*- coding: utf-8 -*-
from typing import List, Dict

import torch
from typing import Dict

from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput

Expand Down
1 change: 0 additions & 1 deletion dwi_ml/unit_tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os

import h5py
import numpy as np
import torch

from dwi_ml.data.dataset.multi_subject_containers import \
Expand Down
3 changes: 1 addition & 2 deletions dwi_ml/unit_tests/test_submethods_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import torch

from dwi_ml.data.processing.space.neighborhood import \
prepare_neighborhood_vectors, extend_coordinates_with_neighborhood, \
unflatten_neighborhood
prepare_neighborhood_vectors, unflatten_neighborhood
from dwi_ml.data.processing.volume.interpolation import \
interpolate_volume_in_neighborhood

Expand Down
15 changes: 4 additions & 11 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
# Supported for python 3.10
# Should work for python > 3.8.

# ----------
# Scilpy and dipy must be installed manually first.
# (or run pip install -r requirements_github.txt)
# Not adding here to let you use your favorite version.
# ----------

# Scilpy and comet_ml both require requests. In comet: >=2.18.*,
# which installs a version >2.28. Adding request version explicitely.
requests==2.28.*
bertviz~=1.4.0 # For transformer's visu
torch==1.13.*
tqdm==4.64.*
Expand All @@ -16,7 +13,7 @@ jupyterlab>=3.6.2 # For transformer's visu
IProgress>=0.4 # For jupyter with tdqm
nested_lookup==0.2.25
nose==1.3.*

scilpy==1.5.post2

## Necessary but should be installed with scilpy (Last check: 09/2023):
future==0.18.*
Expand All @@ -27,10 +24,6 @@ numpy==1.23.*
scipy==1.9.*


# Scilpy requires requests==2.28.*, but comet_ml requires requests>=2.18.*,
# which installs a version >2.28. Adding request version explicitely.
requests==2.28.*


# --------------- Notes to developers
# If we upgrade torch, verify if code copied in
Expand Down
10 changes: 0 additions & 10 deletions requirements_github.txt

This file was deleted.

5 changes: 1 addition & 4 deletions scripts_python/tests/test_all_steps_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
import pytest
import tempfile

import torch

from dwi_ml.unit_tests.utils.expected_values import \
(TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_STREAMLINE_GROUPS,
TEST_EXPECTED_SUBJ_NAMES)
(TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_STREAMLINE_GROUPS)
from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data

data_dir = fetch_testing_data()
Expand Down

0 comments on commit f470467

Please sign in to comment.