Skip to content

Commit

Permalink
ENH: Improve importing optional packages and modules
Browse files Browse the repository at this point in the history
Improve importing optional packages and modules:
- Add helper classes and utils to take care of trying to import optional
  packages, and to provide meaningful information when such packages
  cannot be imported.
- Add the corresponding tests.

Prefer using all lowercases for the flags that contain the
package/module import success value.

Move the `whitematteranalysis/utils.py` module to the
`whitematteranalysis/utils/utils.py` module, and remove the `utils`
module import from `whitematteranalysis/__init__.py`  to avoid module
clashes.
  • Loading branch information
jhlegarreta committed Nov 30, 2023
1 parent afcb30b commit 09f78bc
Show file tree
Hide file tree
Showing 16 changed files with 256 additions and 72 deletions.
5 changes: 3 additions & 2 deletions bin/wm_assess_cluster_location_by_hemisphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import vtk

import whitematteranalysis as wma
from whitematteranalysis.utils.utils import hemisphere_loc_name_typo_warn_msg


def _build_arg_parser():
Expand Down Expand Up @@ -75,7 +76,7 @@ def write_mask_location_to_vtk(inpd, mask_location):
array = inpointdata.GetArray(idx)
if array.GetName() == 'HemisphereLocataion':
warnings.warn(
wma.utils.hemisphere_loc_name_typo_warn_msg,
hemisphere_loc_name_typo_warn_msg,
PendingDeprecationWarning)
print(' -- HemisphereLocataion is in the input data: skip updating the vtk file.')
return inpd
Expand Down Expand Up @@ -122,7 +123,7 @@ def _read_location(_inpd):
array = inpointdata.GetArray(idx)
if array.GetName() == 'HemisphereLocataion':
warnings.warn(
wma.utils.hemisphere_loc_name_typo_warn_msg,
hemisphere_loc_name_typo_warn_msg,
PendingDeprecationWarning)
flag_location = _read_location(inpd)
break
Expand Down
19 changes: 10 additions & 9 deletions bin/wm_quality_control_after_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@
import argparse
import glob
import os
import warnings

import numpy as np

import whitematteranalysis as wma
from whitematteranalysis.utils.opt_pckg import optional_package

HAVE_PLT = 1
try:
import matplotlib
matplotlib, have_mpl, _ = optional_package("matplotlib")
plt, _, _ = optional_package("matplotlib.pyplot")

if have_mpl:
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
except:
print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n")
HAVE_PLT = 0
matplotlib.use("Agg")
else:
warnings.warn(matplotlib._msg)
warnings.warn("Cannot plot quality control data.")


def _build_arg_parser():
Expand Down Expand Up @@ -109,7 +110,7 @@ def main():
print(cidx + 1,'\t', subjects_per_cluster[cidx],'\t', percent_subjects_per_cluster[cidx] * 100.0, file=clusters_qc_file)
clusters_qc_file.close()

if HAVE_PLT:
if have_mpl:
print(f"<{os.path.basename(__file__)}> Saving subjects per cluster histogram.")
fig, ax = plt.subplots()
counts = np.zeros(num_of_subjects+1)
Expand Down
15 changes: 8 additions & 7 deletions bin/wm_quality_control_tract_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
import argparse
import os
import time
import warnings

import numpy as np
import vtk

import whitematteranalysis as wma
from whitematteranalysis.utils.opt_pckg import optional_package

HAVE_PLT = 1
matplotlib, have_mpl, _ = optional_package("matplotlib")
plt, _, _ = optional_package("matplotlib.pyplot")

try:
import matplotlib.pyplot as plt
except:
print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n")
HAVE_PLT = 0
if not have_mpl:
warnings.warn(matplotlib._msg)
warnings.warn("Cannot plot quality control data.")


def _build_arg_parser():
Expand Down Expand Up @@ -69,7 +70,7 @@ def main():

number_of_subjects = len(input_polydatas)

if HAVE_PLT:
if have_mpl:
plt.figure(1)

# Loop over subjects and check each
Expand Down
20 changes: 11 additions & 9 deletions bin/wm_quality_control_tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
import os
import sys
import time
import warnings

import numpy as np
import vtk

import whitematteranalysis as wma
from whitematteranalysis.utils.opt_pckg import optional_package

HAVE_PLT = 1
try:
import matplotlib.pyplot as plt
except:
print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n")
HAVE_PLT = 0
matplotlib, have_mpl, _ = optional_package("matplotlib")
plt, _, _ = optional_package("matplotlib.pyplot")

if not have_mpl:
warnings.warn(matplotlib._msg)
warnings.warn("Cannot plot quality control data.")


def _build_arg_parser():
Expand Down Expand Up @@ -175,7 +177,7 @@ def main():
spatial_qc_file.write(outstr)
spatial_qc_file.close()

if HAVE_PLT:
if have_mpl:
plt.figure(1)

# Loop over subjects and check each
Expand Down Expand Up @@ -272,7 +274,7 @@ def main():
spatial_qc_file.write(outstr)

# Save the subject's fiber lengths
if HAVE_PLT:
if have_mpl:
plt.figure(1)
if lengths.size > 1:
plt.hist(lengths, bins=100, histtype='step', label=subject_id)
Expand Down Expand Up @@ -329,7 +331,7 @@ def main():
del pd3
subject_idx += 1

if HAVE_PLT:
if have_mpl:
plt.figure(1)
plt.title('Histogram of fiber lengths for all subjects')
plt.xlabel('fiber length (mm)')
Expand Down
5 changes: 3 additions & 2 deletions bin/wm_separate_clusters_by_hemisphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import vtk

import whitematteranalysis as wma
from whitematteranalysis.utils.utils import hemisphere_loc_name_typo_warn_msg


def _build_arg_parser():
Expand Down Expand Up @@ -86,7 +87,7 @@ def write_mask_location_to_vtk(inpd, mask_location):
array = inpointdata.GetArray(idx)
if array.GetName() == 'HemisphereLocataion':
warnings.warn(
wma.utils.hemisphere_loc_name_typo_warn_msg,
hemisphere_loc_name_typo_warn_msg,
PendingDeprecationWarning)
print(' -- HemisphereLocataion is in the input data: skip updating the vtk file.')
return inpd
Expand Down Expand Up @@ -133,7 +134,7 @@ def _read_location(_inpd):
array = inpointdata.GetArray(idx)
if array.GetName() == 'HemisphereLocataion':
warnings.warn(
wma.utils.hemisphere_loc_name_typo_warn_msg,
hemisphere_loc_name_typo_warn_msg,
PendingDeprecationWarning)
flag_location = _read_location(inpd)
break
Expand Down
2 changes: 1 addition & 1 deletion whitematteranalysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
io, laterality, mrml, register_two_subjects,
register_two_subjects_nonrigid,
register_two_subjects_nonrigid_bsplines, relative_distance,
render, similarity, tract_measurement, utils)
render, similarity, tract_measurement)
22 changes: 12 additions & 10 deletions whitematteranalysis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,21 @@

from pprint import pprint

from whitematteranalysis.utils.opt_pckg import optional_package

from . import fibers, filter, io, mrml, render, similarity

HAVE_PLT = 1
try:
import matplotlib
matplotlib, have_mpl, _ = optional_package("matplotlib")
plt, _, _ = optional_package("matplotlib.pyplot")

if have_mpl:
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
except:
print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n")
HAVE_PLT = 0
matplotlib.use("Agg")
else:
warnings.warn(matplotlib._msg)
warnings.warn("Cannot plot quality control data.")


# This did not work better. Leave here for future testing if of interest
if 0:
try:
Expand Down Expand Up @@ -986,7 +988,7 @@ def output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_f

clusters_qc_file.close()

if HAVE_PLT:
if have_mpl:
print(f"<{os.path.basename(__file__)}> Saving subjects per cluster histogram.")
fig, ax = plt.subplots()
counts = np.zeros(number_of_subjects+1)
Expand Down
27 changes: 14 additions & 13 deletions whitematteranalysis/congeal_multisubject.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@ class MultiSubjectRegistration

import os
import time
import warnings

import numpy as np
import vtk
from joblib import Parallel, delayed

HAVE_PLT = 1
try:
import matplotlib
import whitematteranalysis as wma
from whitematteranalysis.utils.opt_pckg import optional_package

# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
except:
print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot objectives.\n")
HAVE_PLT = 0
matplotlib, have_mpl, _ = optional_package("matplotlib")
plt, _, _ = optional_package("matplotlib.pyplot")

import whitematteranalysis as wma
if have_mpl:
# Force matplotlib to not use any Xwindows backend.
matplotlib.use("Agg")
else:
warnings.warn(matplotlib._msg)
warnings.warn("Cannot plot objectives.")


class MultiSubjectRegistration:
Expand Down Expand Up @@ -304,7 +305,7 @@ def iterate(self):
functions_per_subject = list()
objective_changes_per_subject = list()
decreases = list()
if HAVE_PLT:
if have_mpl:
plt.close('all')
plt.figure(0)
plt.title('Iteration '+str(self.total_iterations)+' Objective Values for All Subjects')
Expand All @@ -326,7 +327,7 @@ def iterate(self):
objective_total_after += objectives[0]
objective_changes_per_subject.append(diff)
sidx += 1
if HAVE_PLT:
if have_mpl:
plt.figure(0)
plt.plot(objectives, 'o-', label=sidx)

Expand All @@ -342,7 +343,7 @@ def iterate(self):
print("Iteration:", self.total_iterations, "TOTAL objective change:", total_change)
print("Iteration:", self.total_iterations, "PERCENT objective change:", percent_change)

if HAVE_PLT:
if have_mpl:
plt.figure(0)
if self.mode == "Nonrigid":
fname_fig_base = "iteration_%05d_sigma_%03d_grid_%03d" % (self.total_iterations, self.sigma, self.nonrigid_grid_resolution)
Expand Down
26 changes: 15 additions & 11 deletions whitematteranalysis/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,24 @@
"""

import os
import warnings

import numpy as np
import vtk

try:
from joblib import Parallel, delayed
USE_PARALLEL = 1
except ImportError:
USE_PARALLEL = 0
print(f"<{os.path.basename(__file__)}> Failed to import joblib, cannot multiprocess.")
print(f"<{os.path.basename(__file__)}> Please install joblib for this functionality.")
from whitematteranalysis.utils.opt_pckg import optional_package

from . import fibers, similarity

joblib, have_joblib, _ = optional_package("joblib")
Parallel, _, _ = optional_package("joblib.Parallel")
delayed, _, _ = optional_package("joblib.delayed")

if not have_joblib:
warnings.warn(joblib._msg)
warnings.warn("Cannot multiprocess.")


verbose = 0


Expand Down Expand Up @@ -604,7 +608,7 @@ def remove_outliers(inpd, min_fiber_distance, n_jobs=0, distance_method ='Mean')
min_fiber_distance = min_fiber_distance * min_fiber_distance

# pairwise distance matrix
if USE_PARALLEL and n_jobs > 0:
if have_joblib and n_jobs > 0:
distances = Parallel(n_jobs=n_jobs, verbose=1)(
delayed(similarity.fiber_distance)(
fiber_array.get_fiber(lidx),
Expand Down Expand Up @@ -678,7 +682,7 @@ def smooth(inpd, fiber_distance_sigma = 25, points_per_fiber=30, n_jobs=2, upper
print(f"<{os.path.basename(__file__)}> Computing pairwise distances...")

# pairwise distance matrix
if USE_PARALLEL:
if have_joblib:
distances = Parallel(n_jobs=n_jobs, verbose=1)(
delayed(similarity.fiber_distance)(
current_fiber_array.get_fiber(lidx),
Expand Down Expand Up @@ -802,7 +806,7 @@ def anisotropic_smooth(inpd, fiber_distance_threshold, points_per_fiber=30, n_jo
done[np.nonzero(np.array(curr_count) >= cluster_max)] = 1

# pairwise distance matrix
if USE_PARALLEL:
if have_joblib:
distances = Parallel(n_jobs=n_jobs, verbose=1)(
delayed(similarity.fiber_distance)(
current_fiber_array.get_fiber(lidx),
Expand Down Expand Up @@ -967,7 +971,7 @@ def laplacian_of_gaussian(inpd, fiber_distance_sigma = 25, points_per_fiber=30,
fiber_indices = list(range(0, fiber_array.number_of_fibers))

# pairwise distance matrix
if USE_PARALLEL:
if have_joblib:
distances = Parallel(n_jobs=n_jobs, verbose=1)(
delayed(similarity.fiber_distance)(
fiber_array.get_fiber(lidx),
Expand Down
19 changes: 11 additions & 8 deletions whitematteranalysis/laterality.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,25 @@ class ComputeWhiteMatterLaterality
"""

import os
import warnings

import numpy as np
import vtk

try:
from joblib import Parallel, delayed
USE_PARALLEL = 1
except ImportError:
USE_PARALLEL = 0
print(f"<{os.path.basename(__file__)}> Failed to import joblib, cannot multiprocess.")
print(f"<{os.path.basename(__file__)}> Please install joblib for this functionality.")
from whitematteranalysis.utils.opt_pckg import optional_package

from . import filter, similarity
from .fibers import FiberArray
from .io import LateralityResults

joblib, have_joblib, _ = optional_package("joblib")
Parallel, _, _ = optional_package("joblib.Parallel")
delayed, _, _ = optional_package("joblib.delayed")

if not have_joblib:
warnings.warn(joblib._msg)
warnings.warn("Cannot multiprocess.")


def compute_laterality_index(left, right, idx=None):
''' Compute laterality index from left and right hemisphere quantities.'''
Expand Down Expand Up @@ -171,7 +174,7 @@ def compute(self, input_vtk_polydata):
print(f"<{os.path.basename(__file__)}> Starting to compute laterality indices")

# run the computation, either in parallel or not
if (USE_PARALLEL & (self.parallel_jobs > 1)):
if (have_joblib & (self.parallel_jobs > 1)):
if self.verbose:
print(f"<{os.path.basename(__file__)}> Starting parallel code. Processes:", \
self.parallel_jobs)
Expand Down
Empty file.
Loading

0 comments on commit 09f78bc

Please sign in to comment.