From 3673dbd808b9d22cb26b468a4c889e614dfe557c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 21 Oct 2023 17:33:55 -0400 Subject: [PATCH] ENH: Improve importing optional packages and modules Improve importing optional packages and modules: - Add helper classes and utils to take care of trying to import optional packages, and to provide meaningful information when such packages cannot be imported. - Add the corresponding tests. Prefer using all lowercases for the flags that contain the package/module import success value. --- bin/wm_quality_control_after_clustering.py | 19 ++-- bin/wm_quality_control_tract_overlap.py | 15 ++-- bin/wm_quality_control_tractography.py | 20 +++-- whitematteranalysis/cluster.py | 22 ++--- whitematteranalysis/congeal_multisubject.py | 27 +++--- whitematteranalysis/filter.py | 26 +++--- whitematteranalysis/laterality.py | 19 ++-- whitematteranalysis/utils/__init__.py | 0 whitematteranalysis/utils/opt_pckg.py | 87 +++++++++++++++++++ whitematteranalysis/utils/tests/__init__.py | 0 .../utils/tests/test_tripwire.py | 27 ++++++ whitematteranalysis/utils/tripwire.py | 55 ++++++++++++ 12 files changed, 250 insertions(+), 67 deletions(-) create mode 100644 whitematteranalysis/utils/__init__.py create mode 100644 whitematteranalysis/utils/opt_pckg.py create mode 100644 whitematteranalysis/utils/tests/__init__.py create mode 100644 whitematteranalysis/utils/tests/test_tripwire.py create mode 100644 whitematteranalysis/utils/tripwire.py diff --git a/bin/wm_quality_control_after_clustering.py b/bin/wm_quality_control_after_clustering.py index 60904115..4cffd350 100644 --- a/bin/wm_quality_control_after_clustering.py +++ b/bin/wm_quality_control_after_clustering.py @@ -4,21 +4,22 @@ import argparse import glob import os +import warnings import numpy import whitematteranalysis as wma +from whitematteranalysis.utils.opt_pckg import optional_package -HAVE_PLT = 1 -try: - import matplotlib +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") +if have_mpl: # Force matplotlib to not use any Xwindows backend. - matplotlib.use('Agg') - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n") - HAVE_PLT = 0 + matplotlib.use("Agg") +else: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot quality control data.") def _build_arg_parser(): @@ -113,7 +114,7 @@ def main(): print(cidx + 1,'\t', subjects_per_cluster[cidx],'\t', percent_subjects_per_cluster[cidx] * 100.0, file=clusters_qc_file) clusters_qc_file.close() - if HAVE_PLT: + if have_mpl: print(f"<{os.path.basename(__file__)}> Saving subjects per cluster histogram.") fig, ax = plt.subplots() counts = numpy.zeros(num_of_subjects+1) diff --git a/bin/wm_quality_control_tract_overlap.py b/bin/wm_quality_control_tract_overlap.py index ce2a322a..814d5c5d 100755 --- a/bin/wm_quality_control_tract_overlap.py +++ b/bin/wm_quality_control_tract_overlap.py @@ -5,19 +5,20 @@ import argparse import os import time +import warnings import numpy import vtk import whitematteranalysis as wma +from whitematteranalysis.utils.opt_pckg import optional_package -HAVE_PLT = 1 +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") -try: - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n") - HAVE_PLT = 0 +if not have_mpl: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot quality control data.") def _build_arg_parser(): @@ -69,7 +70,7 @@ def main(): number_of_subjects = len(input_polydatas) - if HAVE_PLT: + if have_mpl: plt.figure(1) # Loop over subjects and check each diff --git a/bin/wm_quality_control_tractography.py b/bin/wm_quality_control_tractography.py index 2f865f1a..936f6dd9 100755 --- a/bin/wm_quality_control_tractography.py +++ b/bin/wm_quality_control_tractography.py @@ -5,18 +5,20 @@ import os import sys import time +import warnings import numpy import vtk import whitematteranalysis as wma +from whitematteranalysis.utils.opt_pckg import optional_package -HAVE_PLT = 1 -try: - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n") - HAVE_PLT = 0 +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") + +if not have_mpl: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot quality control data.") def _build_arg_parser(): @@ -175,7 +177,7 @@ def main(): spatial_qc_file.write(outstr) spatial_qc_file.close() - if HAVE_PLT: + if have_mpl: plt.figure(1) # Loop over subjects and check each @@ -272,7 +274,7 @@ def main(): spatial_qc_file.write(outstr) # Save the subject's fiber lengths - if HAVE_PLT: + if have_mpl: plt.figure(1) if lengths.size > 1: plt.hist(lengths, bins=100, histtype='step', label=subject_id) @@ -329,7 +331,7 @@ def main(): del pd3 subject_idx += 1 - if HAVE_PLT: + if have_mpl: plt.figure(1) plt.title('Histogram of fiber lengths for all subjects') plt.xlabel('fiber length (mm)') diff --git a/whitematteranalysis/cluster.py b/whitematteranalysis/cluster.py index 3a4652aa..b577e5b5 100644 --- a/whitematteranalysis/cluster.py +++ b/whitematteranalysis/cluster.py @@ -30,19 +30,21 @@ from pprint import pprint +from whitematteranalysis.utils.opt_pckg import optional_package + from . import fibers, filter, io, mrml, render, similarity -HAVE_PLT = 1 -try: - import matplotlib +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") +if have_mpl: # Force matplotlib to not use any Xwindows backend. - matplotlib.use('Agg') - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot quality control data.\n") - HAVE_PLT = 0 - + matplotlib.use("Agg") +else: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot quality control data.") + + # This did not work better. Leave here for future testing if of interest if 0: try: @@ -986,7 +988,7 @@ def output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_f clusters_qc_file.close() - if HAVE_PLT: + if have_mpl: print(f"<{os.path.basename(__file__)}> Saving subjects per cluster histogram.") fig, ax = plt.subplots() counts = numpy.zeros(number_of_subjects+1) diff --git a/whitematteranalysis/congeal_multisubject.py b/whitematteranalysis/congeal_multisubject.py index f0463d5a..e0481617 100644 --- a/whitematteranalysis/congeal_multisubject.py +++ b/whitematteranalysis/congeal_multisubject.py @@ -11,23 +11,24 @@ class MultiSubjectRegistration import os import time +import warnings import numpy import vtk from joblib import Parallel, delayed -HAVE_PLT = 1 -try: - import matplotlib +import whitematteranalysis as wma +from whitematteranalysis.utils.opt_pckg import optional_package - # Force matplotlib to not use any Xwindows backend. - matplotlib.use('Agg') - import matplotlib.pyplot as plt -except: - print(f"<{os.path.basename(__file__)}> Error importing matplotlib.pyplot package, can't plot objectives.\n") - HAVE_PLT = 0 +matplotlib, have_mpl, _ = optional_package("matplotlib") +plt, _, _ = optional_package("matplotlib.pyplot") -import whitematteranalysis as wma +if have_mpl: + # Force matplotlib to not use any Xwindows backend. + matplotlib.use("Agg") +else: + warnings.warn(matplotlib._msg) + warnings.warn("Cannot plot objectives.") class MultiSubjectRegistration: @@ -304,7 +305,7 @@ def iterate(self): functions_per_subject = list() objective_changes_per_subject = list() decreases = list() - if HAVE_PLT: + if have_mpl: plt.close('all') plt.figure(0) plt.title('Iteration '+str(self.total_iterations)+' Objective Values for All Subjects') @@ -326,7 +327,7 @@ def iterate(self): objective_total_after += objectives[0] objective_changes_per_subject.append(diff) sidx += 1 - if HAVE_PLT: + if have_mpl: plt.figure(0) plt.plot(objectives, 'o-', label=sidx) @@ -342,7 +343,7 @@ def iterate(self): print("Iteration:", self.total_iterations, "TOTAL objective change:", total_change) print("Iteration:", self.total_iterations, "PERCENT objective change:", percent_change) - if HAVE_PLT: + if have_mpl: plt.figure(0) if self.mode == "Nonrigid": fname_fig_base = "iteration_%05d_sigma_%03d_grid_%03d" % (self.total_iterations, self.sigma, self.nonrigid_grid_resolution) diff --git a/whitematteranalysis/filter.py b/whitematteranalysis/filter.py index a1137c4a..2522e9d0 100644 --- a/whitematteranalysis/filter.py +++ b/whitematteranalysis/filter.py @@ -20,20 +20,24 @@ """ import os +import warnings import numpy import vtk -try: - from joblib import Parallel, delayed - USE_PARALLEL = 1 -except ImportError: - USE_PARALLEL = 0 - print(f"<{os.path.basename(__file__)}> Failed to import joblib, cannot multiprocess.") - print(f"<{os.path.basename(__file__)}> Please install joblib for this functionality.") +from whitematteranalysis.utils.opt_pckg import optional_package from . import fibers, similarity +joblib, have_joblib, _ = optional_package("joblib") +Parallel, _, _ = optional_package("joblib.Parallel") +delayed, _, _ = optional_package("joblib.delayed") + +if not have_joblib: + warnings.warn(joblib._msg) + warnings.warn("Cannot multiprocess.") + + verbose = 0 @@ -604,7 +608,7 @@ def remove_outliers(inpd, min_fiber_distance, n_jobs=0, distance_method ='Mean') min_fiber_distance = min_fiber_distance * min_fiber_distance # pairwise distance matrix - if USE_PARALLEL and n_jobs > 0: + if have_joblib and n_jobs > 0: distances = Parallel(n_jobs=n_jobs, verbose=1)( delayed(similarity.fiber_distance)( fiber_array.get_fiber(lidx), @@ -678,7 +682,7 @@ def smooth(inpd, fiber_distance_sigma = 25, points_per_fiber=30, n_jobs=2, upper print(f"<{os.path.basename(__file__)}> Computing pairwise distances...") # pairwise distance matrix - if USE_PARALLEL: + if have_joblib: distances = Parallel(n_jobs=n_jobs, verbose=1)( delayed(similarity.fiber_distance)( current_fiber_array.get_fiber(lidx), @@ -802,7 +806,7 @@ def anisotropic_smooth(inpd, fiber_distance_threshold, points_per_fiber=30, n_jo done[numpy.nonzero(numpy.array(curr_count) >= cluster_max)] = 1 # pairwise distance matrix - if USE_PARALLEL: + if have_joblib: distances = Parallel(n_jobs=n_jobs, verbose=1)( delayed(similarity.fiber_distance)( current_fiber_array.get_fiber(lidx), @@ -967,7 +971,7 @@ def laplacian_of_gaussian(inpd, fiber_distance_sigma = 25, points_per_fiber=30, fiber_indices = list(range(0, fiber_array.number_of_fibers)) # pairwise distance matrix - if USE_PARALLEL: + if have_joblib: distances = Parallel(n_jobs=n_jobs, verbose=1)( delayed(similarity.fiber_distance)( fiber_array.get_fiber(lidx), diff --git a/whitematteranalysis/laterality.py b/whitematteranalysis/laterality.py index dd1cbd29..e9400cf4 100644 --- a/whitematteranalysis/laterality.py +++ b/whitematteranalysis/laterality.py @@ -21,22 +21,25 @@ class ComputeWhiteMatterLaterality """ import os +import warnings import numpy import vtk -try: - from joblib import Parallel, delayed - USE_PARALLEL = 1 -except ImportError: - USE_PARALLEL = 0 - print(f"<{os.path.basename(__file__)}> Failed to import joblib, cannot multiprocess.") - print(f"<{os.path.basename(__file__)}> Please install joblib for this functionality.") +from whitematteranalysis.utils.opt_pckg import optional_package from . import filter, similarity from .fibers import FiberArray from .io import LateralityResults +joblib, have_joblib, _ = optional_package("joblib") +Parallel, _, _ = optional_package("joblib.Parallel") +delayed, _, _ = optional_package("joblib.delayed") + +if not have_joblib: + warnings.warn(joblib._msg) + warnings.warn("Cannot multiprocess.") + def compute_laterality_index(left, right, idx=None): ''' Compute laterality index from left and right hemisphere quantities.''' @@ -171,7 +174,7 @@ def compute(self, input_vtk_polydata): print(f"<{os.path.basename(__file__)}> Starting to compute laterality indices") # run the computation, either in parallel or not - if (USE_PARALLEL & (self.parallel_jobs > 1)): + if (have_joblib & (self.parallel_jobs > 1)): if self.verbose: print(f"<{os.path.basename(__file__)}> Starting parallel code. Processes:", \ self.parallel_jobs) diff --git a/whitematteranalysis/utils/__init__.py b/whitematteranalysis/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/whitematteranalysis/utils/opt_pckg.py b/whitematteranalysis/utils/opt_pckg.py new file mode 100644 index 00000000..4b74e1f7 --- /dev/null +++ b/whitematteranalysis/utils/opt_pckg.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- + +"""Methods to support optional packages.""" + +import importlib + +try: + import pytest +except ImportError: + have_pytest = False +else: + have_pytest = True + +from whitematteranalysis.utils.tripwire import TripWire + + +def optional_package(name, trip_msg=None): + """Return package-like thing and module setup for package ``name``. + + Parameters + ---------- + name : str + Package name. + trip_msg : None or str + Message to be shown when the specified package cannot be imported. + + Returns + ------- + pckg : package, module or ``TripWire`` instance + If the package can be imported, return it. Otherwise, return an object + raising an error when accessed. + have_pkg : bool + True if import for package was successful, False otherwise. + module_setup : function + Callable usually set as ``setup_module`` in calling namespace, to allow + skipping tests. + + Examples + -------- + Typical use: + + >>> from whitematteranalysis.utils.opt_pckg import optional_package + >>> pckg, have_pckg, _setup_module = optional_package("not_a_package") + + In this case the package doesn't exist, and therefore: + + >>> have_pckg + False + + and + + >>> pckg.some_function() #doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TripWireError: We need package not_a_package for these functions, but + ``import not_a_package`` raised an ImportError + + If the module does exist - we get the module + + >>> pckg, _, _ = optional_package("os") + >>> hasattr(pckg, "path") + True + + Or a submodule if that's what we asked for + + >>> subpckg, _, _ = optional_package("os.path") + >>> hasattr(subpckg, "dirname") + True + """ + + try: + pckg = importlib.import_module(name) + except ImportError: + pass + else: # import worked + # top level module + return pckg, True, lambda: None + if trip_msg is None: + trip_msg = ( + f"{name} needed, but ``import f{name}`` raised an ``ImportError``.") + pckg = TripWire(trip_msg) + + def setup_module(): + if have_pytest: + pytest.mark.skip(f"No {name} for these tests.") + + return pckg, False, setup_module diff --git a/whitematteranalysis/utils/tests/__init__.py b/whitematteranalysis/utils/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/whitematteranalysis/utils/tests/test_tripwire.py b/whitematteranalysis/utils/tests/test_tripwire.py new file mode 100644 index 00000000..d4d5bc3e --- /dev/null +++ b/whitematteranalysis/utils/tests/test_tripwire.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import unittest + +from whitematteranalysis.utils.tripwire import (TripWire, TripWireError, + is_tripwire) + + +def test_is_tripwire(): + assert not is_tripwire(object()) + assert is_tripwire(TripWire("some message")) + + +def test_tripwire(): + # Test tripwire object + not_a_package_name = TripWire("Do not have not_a_package") + unittest.TestCase.assertRaises( + TripWireError, getattr, not_a_package_name, "do_something") + unittest.TestCase.assertRaises(TripWireError, not_a_package_name) + # Check AttributeError can be checked too + try: + not_a_package_name.__wrapped__ + except TripWireError as err: + assert_true(isinstance(err, AttributeError)) + else: + raise RuntimeError("No error raised, but expected.") diff --git a/whitematteranalysis/utils/tripwire.py b/whitematteranalysis/utils/tripwire.py new file mode 100644 index 00000000..0ee9dcc9 --- /dev/null +++ b/whitematteranalysis/utils/tripwire.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- + +class TripWireError(AttributeError): + """Class to raise an exception for missing modules or other misfortunes..""" + + +def is_tripwire(obj): + """Returns True if ``obj`` appears to be a ``TripWire`` object. + + Examples + -------- + >>> is_tripwire(object()) + False + >>> is_tripwire(TripWire("some message")) + True + """ + + try: + obj.any_attribute + except TripWireError: + return True + except Exception: + pass + return False + + +class TripWire: + """Class raising error if used. + + Standard use is to proxy modules that could not be imported. + + Examples + -------- + >>> try: + ... import not_a_package + ... except ImportError: + ... not_a_package_name = TripWire("Do not have not_a_package_name") + >>> not_a_package_name.do_something("with argument") #doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TripWireError: Do not have not_a_package_name + """ + + def __init__(self, msg): + self._msg = msg + + def __getattr__(self, attr_name): + """Raise informative error accessing attributes. + """ + raise TripWireError(self._msg) + + def __call__(self, *args, **kwargs): + """Raise informative error while calling. + """ + raise TripWireError(self._msg)