From 09f78bca14ee21e6e71cb29c05d76527cd6cabe6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 21 Oct 2023 17:33:55 -0400 Subject: [PATCH] ENH: Improve importing optional packages and modules Improve importing optional packages and modules: - Add helper classes and utils to take care of trying to import optional packages, and to provide meaningful information when such packages cannot be imported. - Add the corresponding tests. Prefer using all lowercases for the flags that contain the package/module import success value. Move the `whitematteranalysis/utils.py` module to the `whitematteranalysis/utils/utils.py` module, and remove the `utils` module import from `whitematteranalysis/__init__.py` to avoid module clashes. --- ...m_assess_cluster_location_by_hemisphere.py | 5 +- bin/wm_quality_control_after_clustering.py | 19 ++-- bin/wm_quality_control_tract_overlap.py | 15 ++-- bin/wm_quality_control_tractography.py | 20 +++-- bin/wm_separate_clusters_by_hemisphere.py | 5 +- whitematteranalysis/__init__.py | 2 +- whitematteranalysis/cluster.py | 22 ++--- whitematteranalysis/congeal_multisubject.py | 27 +++--- whitematteranalysis/filter.py | 26 +++--- whitematteranalysis/laterality.py | 19 ++-- whitematteranalysis/utils/__init__.py | 0 whitematteranalysis/utils/opt_pckg.py | 87 +++++++++++++++++++ whitematteranalysis/utils/tests/__init__.py | 0 .../utils/tests/test_tripwire.py | 26 ++++++ whitematteranalysis/utils/tripwire.py | 55 ++++++++++++ whitematteranalysis/{ => utils}/utils.py | 0 16 files changed, 256 insertions(+), 72 deletions(-) create mode 100644 whitematteranalysis/utils/__init__.py create mode 100644 whitematteranalysis/utils/opt_pckg.py create mode 100644 whitematteranalysis/utils/tests/__init__.py create mode 100644 whitematteranalysis/utils/tests/test_tripwire.py create mode 100644 whitematteranalysis/utils/tripwire.py rename whitematteranalysis/{ => utils}/utils.py (100%) diff --git a/bin/wm_assess_cluster_location_by_hemisphere.py b/bin/wm_assess_cluster_location_by_hemisphere.py index ae7320d5..f7baa15f 100755 --- a/bin/wm_assess_cluster_location_by_hemisphere.py +++ b/bin/wm_assess_cluster_location_by_hemisphere.py @@ -21,6 +21,7 @@ import vtk import whitematteranalysis as wma +from whitematteranalysis.utils.utils import hemisphere_loc_name_typo_warn_msg def _build_arg_parser(): @@ -75,7 +76,7 @@ def write_mask_location_to_vtk(inpd, mask_location): array = inpointdata.GetArray(idx) if array.GetName() == 'HemisphereLocataion': warnings.warn( - wma.utils.hemisphere_loc_name_typo_warn_msg, + hemisphere_loc_name_typo_warn_msg, PendingDeprecationWarning) print(' -- HemisphereLocataion is in the input data: skip updating the vtk file.') return inpd @@ -122,7 +123,7 @@ def _read_location(_inpd): array = inpointdata.GetArray(idx) if array.GetName() == 'HemisphereLocataion': warnings.warn( - wma.utils.hemisphere_loc_name_typo_warn_msg, + hemisphere_loc_name_typo_warn_msg, PendingDeprecationWarning) flag_location = _read_location(inpd) break diff --git a/bin/wm_quality_control_after_clustering.py b/bin/wm_quality_control_after_clustering.py index 5747a12b..a47c8e64 100755 --- a/bin/wm_quality_control_after_clustering.py +++ b/bin/wm_quality_control_after_clustering.py @@ -4,21 +4,22 @@ import argparse import glob import os +import warnings import numpy as np import whitematteranalysis as wma +from whitematteranalysis.utils.opt_pckg import optional_package -HAVE_PLT = 1 -try: - import matplotlib +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") +if have_mpl: # Force matplotlib to not use any Xwindows backend. - matplotlib.use('Agg') - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n") - HAVE_PLT = 0 + matplotlib.use("Agg") +else: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot quality control data.") def _build_arg_parser(): @@ -109,7 +110,7 @@ def main(): print(cidx + 1,'\t', subjects_per_cluster[cidx],'\t', percent_subjects_per_cluster[cidx] * 100.0, file=clusters_qc_file) clusters_qc_file.close() - if HAVE_PLT: + if have_mpl: print(f"<{os.path.basename(__file__)}> Saving subjects per cluster histogram.") fig, ax = plt.subplots() counts = np.zeros(num_of_subjects+1) diff --git a/bin/wm_quality_control_tract_overlap.py b/bin/wm_quality_control_tract_overlap.py index 81b5095b..a56e8068 100755 --- a/bin/wm_quality_control_tract_overlap.py +++ b/bin/wm_quality_control_tract_overlap.py @@ -5,19 +5,20 @@ import argparse import os import time +import warnings import numpy as np import vtk import whitematteranalysis as wma +from whitematteranalysis.utils.opt_pckg import optional_package -HAVE_PLT = 1 +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") -try: - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n") - HAVE_PLT = 0 +if not have_mpl: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot quality control data.") def _build_arg_parser(): @@ -69,7 +70,7 @@ def main(): number_of_subjects = len(input_polydatas) - if HAVE_PLT: + if have_mpl: plt.figure(1) # Loop over subjects and check each diff --git a/bin/wm_quality_control_tractography.py b/bin/wm_quality_control_tractography.py index cc82483b..533e0943 100755 --- a/bin/wm_quality_control_tractography.py +++ b/bin/wm_quality_control_tractography.py @@ -5,18 +5,20 @@ import os import sys import time +import warnings import numpy as np import vtk import whitematteranalysis as wma +from whitematteranalysis.utils.opt_pckg import optional_package -HAVE_PLT = 1 -try: - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n") - HAVE_PLT = 0 +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") + +if not have_mpl: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot quality control data.") def _build_arg_parser(): @@ -175,7 +177,7 @@ def main(): spatial_qc_file.write(outstr) spatial_qc_file.close() - if HAVE_PLT: + if have_mpl: plt.figure(1) # Loop over subjects and check each @@ -272,7 +274,7 @@ def main(): spatial_qc_file.write(outstr) # Save the subject's fiber lengths - if HAVE_PLT: + if have_mpl: plt.figure(1) if lengths.size > 1: plt.hist(lengths, bins=100, histtype='step', label=subject_id) @@ -329,7 +331,7 @@ def main(): del pd3 subject_idx += 1 - if HAVE_PLT: + if have_mpl: plt.figure(1) plt.title('Histogram of fiber lengths for all subjects') plt.xlabel('fiber length (mm)') diff --git a/bin/wm_separate_clusters_by_hemisphere.py b/bin/wm_separate_clusters_by_hemisphere.py index e3d6afbf..e96e72d7 100755 --- a/bin/wm_separate_clusters_by_hemisphere.py +++ b/bin/wm_separate_clusters_by_hemisphere.py @@ -21,6 +21,7 @@ import vtk import whitematteranalysis as wma +from whitematteranalysis.utils.utils import hemisphere_loc_name_typo_warn_msg def _build_arg_parser(): @@ -86,7 +87,7 @@ def write_mask_location_to_vtk(inpd, mask_location): array = inpointdata.GetArray(idx) if array.GetName() == 'HemisphereLocataion': warnings.warn( - wma.utils.hemisphere_loc_name_typo_warn_msg, + hemisphere_loc_name_typo_warn_msg, PendingDeprecationWarning) print(' -- HemisphereLocataion is in the input data: skip updating the vtk file.') return inpd @@ -133,7 +134,7 @@ def _read_location(_inpd): array = inpointdata.GetArray(idx) if array.GetName() == 'HemisphereLocataion': warnings.warn( - wma.utils.hemisphere_loc_name_typo_warn_msg, + hemisphere_loc_name_typo_warn_msg, PendingDeprecationWarning) flag_location = _read_location(inpd) break diff --git a/whitematteranalysis/__init__.py b/whitematteranalysis/__init__.py index 78e85f2d..fdb55e0f 100644 --- a/whitematteranalysis/__init__.py +++ b/whitematteranalysis/__init__.py @@ -3,4 +3,4 @@ io, laterality, mrml, register_two_subjects, register_two_subjects_nonrigid, register_two_subjects_nonrigid_bsplines, relative_distance, - render, similarity, tract_measurement, utils) + render, similarity, tract_measurement) diff --git a/whitematteranalysis/cluster.py b/whitematteranalysis/cluster.py index 3437727e..c6b779db 100644 --- a/whitematteranalysis/cluster.py +++ b/whitematteranalysis/cluster.py @@ -30,19 +30,21 @@ from pprint import pprint +from whitematteranalysis.utils.opt_pckg import optional_package + from . import fibers, filter, io, mrml, render, similarity -HAVE_PLT = 1 -try: - import matplotlib +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") +if have_mpl: # Force matplotlib to not use any Xwindows backend. - matplotlib.use('Agg') - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n") - HAVE_PLT = 0 - + matplotlib.use("Agg") +else: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot quality control data.") + + # This did not work better. Leave here for future testing if of interest if 0: try: @@ -986,7 +988,7 @@ def output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_f clusters_qc_file.close() - if HAVE_PLT: + if have_mpl: print(f"<{os.path.basename(__file__)}> Saving subjects per cluster histogram.") fig, ax = plt.subplots() counts = np.zeros(number_of_subjects+1) diff --git a/whitematteranalysis/congeal_multisubject.py b/whitematteranalysis/congeal_multisubject.py index 0e4d291f..cca5adf8 100644 --- a/whitematteranalysis/congeal_multisubject.py +++ b/whitematteranalysis/congeal_multisubject.py @@ -11,23 +11,24 @@ class MultiSubjectRegistration import os import time +import warnings import numpy as np import vtk from joblib import Parallel, delayed -HAVE_PLT = 1 -try: - import matplotlib +import whitematteranalysis as wma +from whitematteranalysis.utils.opt_pckg import optional_package - # Force matplotlib to not use any Xwindows backend. - matplotlib.use('Agg') - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot objectives.\n") - HAVE_PLT = 0 +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") -import whitematteranalysis as wma +if have_mpl: + # Force matplotlib to not use any Xwindows backend. + matplotlib.use("Agg") +else: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot objectives.") class MultiSubjectRegistration: @@ -304,7 +305,7 @@ def iterate(self): functions_per_subject = list() objective_changes_per_subject = list() decreases = list() - if HAVE_PLT: + if have_mpl: plt.close('all') plt.figure(0) plt.title('Iteration '+str(self.total_iterations)+' Objective Values for All Subjects') @@ -326,7 +327,7 @@ def iterate(self): objective_total_after += objectives[0] objective_changes_per_subject.append(diff) sidx += 1 - if HAVE_PLT: + if have_mpl: plt.figure(0) plt.plot(objectives, 'o-', label=sidx) @@ -342,7 +343,7 @@ def iterate(self): print("Iteration:", self.total_iterations, "TOTAL objective change:", total_change) print("Iteration:", self.total_iterations, "PERCENT objective change:", percent_change) - if HAVE_PLT: + if have_mpl: plt.figure(0) if self.mode == "Nonrigid": fname_fig_base = "iteration_%05d_sigma_%03d_grid_%03d" % (self.total_iterations, self.sigma, self.nonrigid_grid_resolution) diff --git a/whitematteranalysis/filter.py b/whitematteranalysis/filter.py index e9e6ece0..e46e9138 100644 --- a/whitematteranalysis/filter.py +++ b/whitematteranalysis/filter.py @@ -20,20 +20,24 @@ """ import os +import warnings import numpy as np import vtk -try: - from joblib import Parallel, delayed - USE_PARALLEL = 1 -except ImportError: - USE_PARALLEL = 0 - print(f"<{os.path.basename(__file__)}> Failed to import joblib, cannot multiprocess.") - print(f"<{os.path.basename(__file__)}> Please install joblib for this functionality.") +from whitematteranalysis.utils.opt_pckg import optional_package from . import fibers, similarity +joblib, have_joblib, _ = optional_package("joblib") +Parallel, _, _ = optional_package("joblib.Parallel") +delayed, _, _ = optional_package("joblib.delayed") + +if not have_joblib: + warnings.warn(joblib._msg) + warnings.warn("Cannot multiprocess.") + + verbose = 0 @@ -604,7 +608,7 @@ def remove_outliers(inpd, min_fiber_distance, n_jobs=0, distance_method ='Mean') min_fiber_distance = min_fiber_distance * min_fiber_distance # pairwise distance matrix - if USE_PARALLEL and n_jobs > 0: + if have_joblib and n_jobs > 0: distances = Parallel(n_jobs=n_jobs, verbose=1)( delayed(similarity.fiber_distance)( fiber_array.get_fiber(lidx), @@ -678,7 +682,7 @@ def smooth(inpd, fiber_distance_sigma = 25, points_per_fiber=30, n_jobs=2, upper print(f"<{os.path.basename(__file__)}> Computing pairwise distances...") # pairwise distance matrix - if USE_PARALLEL: + if have_joblib: distances = Parallel(n_jobs=n_jobs, verbose=1)( delayed(similarity.fiber_distance)( current_fiber_array.get_fiber(lidx), @@ -802,7 +806,7 @@ def anisotropic_smooth(inpd, fiber_distance_threshold, points_per_fiber=30, n_jo done[np.nonzero(np.array(curr_count) >= cluster_max)] = 1 # pairwise distance matrix - if USE_PARALLEL: + if have_joblib: distances = Parallel(n_jobs=n_jobs, verbose=1)( delayed(similarity.fiber_distance)( current_fiber_array.get_fiber(lidx), @@ -967,7 +971,7 @@ def laplacian_of_gaussian(inpd, fiber_distance_sigma = 25, points_per_fiber=30, fiber_indices = list(range(0, fiber_array.number_of_fibers)) # pairwise distance matrix - if USE_PARALLEL: + if have_joblib: distances = Parallel(n_jobs=n_jobs, verbose=1)( delayed(similarity.fiber_distance)( fiber_array.get_fiber(lidx), diff --git a/whitematteranalysis/laterality.py b/whitematteranalysis/laterality.py index 184d0026..d0a67598 100644 --- a/whitematteranalysis/laterality.py +++ b/whitematteranalysis/laterality.py @@ -21,22 +21,25 @@ class ComputeWhiteMatterLaterality """ import os +import warnings import numpy as np import vtk -try: - from joblib import Parallel, delayed - USE_PARALLEL = 1 -except ImportError: - USE_PARALLEL = 0 - print(f"<{os.path.basename(__file__)}> Failed to import joblib, cannot multiprocess.") - print(f"<{os.path.basename(__file__)}> Please install joblib for this functionality.") +from whitematteranalysis.utils.opt_pckg import optional_package from . import filter, similarity from .fibers import FiberArray from .io import LateralityResults +joblib, have_joblib, _ = optional_package("joblib") +Parallel, _, _ = optional_package("joblib.Parallel") +delayed, _, _ = optional_package("joblib.delayed") + +if not have_joblib: + warnings.warn(joblib._msg) + warnings.warn("Cannot multiprocess.") + def compute_laterality_index(left, right, idx=None): ''' Compute laterality index from left and right hemisphere quantities.''' @@ -171,7 +174,7 @@ def compute(self, input_vtk_polydata): print(f"<{os.path.basename(__file__)}> Starting to compute laterality indices") # run the computation, either in parallel or not - if (USE_PARALLEL & (self.parallel_jobs > 1)): + if (have_joblib & (self.parallel_jobs > 1)): if self.verbose: print(f"<{os.path.basename(__file__)}> Starting parallel code. Processes:", \ self.parallel_jobs) diff --git a/whitematteranalysis/utils/__init__.py b/whitematteranalysis/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/whitematteranalysis/utils/opt_pckg.py b/whitematteranalysis/utils/opt_pckg.py new file mode 100644 index 00000000..4b74e1f7 --- /dev/null +++ b/whitematteranalysis/utils/opt_pckg.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- + +"""Methods to support optional packages.""" + +import importlib + +try: + import pytest +except ImportError: + have_pytest = False +else: + have_pytest = True + +from whitematteranalysis.utils.tripwire import TripWire + + +def optional_package(name, trip_msg=None): + """Return package-like thing and module setup for package ``name``. + + Parameters + ---------- + name : str + Package name. + trip_msg : None or str + Message to be shown when the specified package cannot be imported. + + Returns + ------- + pckg : package, module or ``TripWire`` instance + If the package can be imported, return it. Otherwise, return an object + raising an error when accessed. + have_pkg : bool + True if import for package was successful, False otherwise. + module_setup : function + Callable usually set as ``setup_module`` in calling namespace, to allow + skipping tests. + + Examples + -------- + Typical use: + + >>> from whitematteranalysis.utils.opt_pckg import optional_package + >>> pckg, have_pckg, _setup_module = optional_package("not_a_package") + + In this case the package doesn't exist, and therefore: + + >>> have_pckg + False + + and + + >>> pckg.some_function() #doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TripWireError: We need package not_a_package for these functions, but + ``import not_a_package`` raised an ImportError + + If the module does exist - we get the module + + >>> pckg, _, _ = optional_package("os") + >>> hasattr(pckg, "path") + True + + Or a submodule if that's what we asked for + + >>> subpckg, _, _ = optional_package("os.path") + >>> hasattr(subpckg, "dirname") + True + """ + + try: + pckg = importlib.import_module(name) + except ImportError: + pass + else: # import worked + # top level module + return pckg, True, lambda: None + if trip_msg is None: + trip_msg = ( + f"{name} needed, but ``import f{name}`` raised an ``ImportError``.") + pckg = TripWire(trip_msg) + + def setup_module(): + if have_pytest: + pytest.mark.skip(f"No {name} for these tests.") + + return pckg, False, setup_module diff --git a/whitematteranalysis/utils/tests/__init__.py b/whitematteranalysis/utils/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/whitematteranalysis/utils/tests/test_tripwire.py b/whitematteranalysis/utils/tests/test_tripwire.py new file mode 100644 index 00000000..632e0d6a --- /dev/null +++ b/whitematteranalysis/utils/tests/test_tripwire.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from numpy.testing import assert_raises + +from whitematteranalysis.utils.tripwire import (TripWire, TripWireError, + is_tripwire) + + +def test_is_tripwire(): + assert not is_tripwire(object()) + assert is_tripwire(TripWire("some message")) + + +def test_tripwire(): + # Test tripwire object + not_a_package_name = TripWire("Do not have not_a_package") + assert_raises(TripWireError, getattr, not_a_package_name, "do_something") + assert_raises(TripWireError, not_a_package_name) + # Check AttributeError can be checked too + try: + not_a_package_name.__wrapped__ + except TripWireError as err: + assert isinstance(err, AttributeError) + else: + raise RuntimeError("No error raised, but expected.") diff --git a/whitematteranalysis/utils/tripwire.py b/whitematteranalysis/utils/tripwire.py new file mode 100644 index 00000000..0ee9dcc9 --- /dev/null +++ b/whitematteranalysis/utils/tripwire.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- + +class TripWireError(AttributeError): + """Class to raise an exception for missing modules or other misfortunes..""" + + +def is_tripwire(obj): + """Returns True if ``obj`` appears to be a ``TripWire`` object. + + Examples + -------- + >>> is_tripwire(object()) + False + >>> is_tripwire(TripWire("some message")) + True + """ + + try: + obj.any_attribute + except TripWireError: + return True + except Exception: + pass + return False + + +class TripWire: + """Class raising error if used. + + Standard use is to proxy modules that could not be imported. + + Examples + -------- + >>> try: + ... import not_a_package + ... except ImportError: + ... not_a_package_name = TripWire("Do not have not_a_package_name") + >>> not_a_package_name.do_something("with argument") #doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TripWireError: Do not have not_a_package_name + """ + + def __init__(self, msg): + self._msg = msg + + def __getattr__(self, attr_name): + """Raise informative error accessing attributes. + """ + raise TripWireError(self._msg) + + def __call__(self, *args, **kwargs): + """Raise informative error while calling. + """ + raise TripWireError(self._msg) diff --git a/whitematteranalysis/utils.py b/whitematteranalysis/utils/utils.py similarity index 100% rename from whitematteranalysis/utils.py rename to whitematteranalysis/utils/utils.py