Skip to content

Commit

Permalink
add get_modules convenience (#541)
Browse files Browse the repository at this point in the history
  • Loading branch information
bjlittle authored Nov 29, 2023
1 parent d449616 commit a84b2b4
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 29 deletions.
21 changes: 9 additions & 12 deletions src/geovista/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@

import importlib
import pathlib
import pkgutil
from shutil import rmtree

import click
from click_default_group import DefaultGroup
import pyvista as pv

from . import examples as scripts
from . import logger
from ._version import version as __version__
from .cache import CACHE, GEOVISTA_POOCH_MUTE, pooch_mute
from .common import get_modules
from .config import resources
from .geoplotter import GeoPlotter
from .report import Report
Expand All @@ -41,9 +40,7 @@

FG_COLOUR: str = "cyan"

SCRIPTS: list[str] = [ALL] + [
submodule.name for submodule in pkgutil.iter_modules(scripts.__path__)
]
EXAMPLES: list[str] = [ALL] + get_modules("geovista.examples")


def _download_group(
Expand Down Expand Up @@ -375,7 +372,7 @@ def collect(prefix: str) -> list[str]:
@click.option(
"-r",
"--run",
type=click.Choice(SCRIPTS, case_sensitive=False),
type=click.Choice(EXAMPLES, case_sensitive=False),
is_flag=False,
help="Execute the example.",
)
Expand All @@ -388,13 +385,13 @@ def collect(prefix: str) -> list[str]:
def examples(run_all: bool, show: bool, run: bool, verbose: bool) -> None:
"""Execute a geovista example script."""
# account for the initial "all" option
n_scripts = len(SCRIPTS) - 1
n_examples = len(EXAMPLES) - 1

if show:
click.echo("Names of available examples:")
width = len(str(n_scripts))
for i, script in enumerate(SCRIPTS[1:]):
click.echo(f"[{i + 1:0{width}d}/{n_scripts}] ", nl=False)
width = len(str(n_examples))
for i, script in enumerate(EXAMPLES[1:]):
click.echo(f"[{i + 1:0{width}d}/{n_examples}] ", nl=False)
click.secho(f"{script}", fg="green")
click.echo("\n👍 All done!")
return
Expand All @@ -405,8 +402,8 @@ def examples(run_all: bool, show: bool, run: bool, verbose: bool) -> None:
logger.setLevel("INFO")

if run_all:
for i, script in enumerate(SCRIPTS[1:]):
msg = f"Running example {script!r} ({i+1} of {n_scripts}) ..."
for i, script in enumerate(EXAMPLES[1:]):
msg = f"Running example {script!r} ({i+1} of {n_examples}) ..."
click.secho(msg, fg="green")
module = importlib.import_module(f"geovista.examples.{script}")
if verbose:
Expand Down
45 changes: 45 additions & 0 deletions src/geovista/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from collections.abc import Iterable
from enum import Enum
import importlib
import pkgutil
import sys
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -54,6 +56,7 @@
"cast_UnstructuredGrid_to_PolyData",
"distance",
"from_cartesian",
"get_modules",
"nan_mask",
"point_cloud",
"sanitize_data",
Expand Down Expand Up @@ -511,6 +514,48 @@ def from_cartesian(
return np.vstack(data).T if stacked else np.array(data)


def get_modules(root: str, base: bool | None = True) -> list[str]:
"""Find all submodule names relative to the `root` package.
Recursively searches down from the `root` to find all child (leaf) modules.
The names of the modules will be relative to the `root`.
Parameters
----------
root : str
The name (dot notation) of the top level package to search under.
e.g., ``geovista.examples``.
base : bool, optional
Flag the top level `root` package, which will then remove the `root` prefix
from all packages found and sort them alphabetically.
Returns
-------
list of str
The sorted list of child module names, relative to the `root`.
Notes
-----
.. versionadded:: 0.5.0
"""
modules, pkgs = [], []

for info in pkgutil.iter_modules(importlib.import_module(root).__path__):
name = f"{root}.{info.name}"
container = pkgs if info.ispkg else modules
container.append(name)

for pkg in pkgs:
modules.extend(get_modules(pkg, base=False))

if base:
modules = sorted([name.split(f"{root}.")[1] for name in modules])

return modules


def nan_mask(data: ArrayLike) -> np.ndarray:
"""Replace any masked array values with NaNs.
Expand Down
30 changes: 30 additions & 0 deletions tests/common/test_get_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2021, GeoVista Contributors.
#
# This file is part of GeoVista and is distributed under the 3-Clause BSD license.
# See the LICENSE file in the package root directory for licensing details.

"""Unit-tests for :func:`geovista.common.get_modules`."""
from __future__ import annotations

import importlib
from os import sep
from pathlib import Path

from geovista.common import get_modules


def test_package_walk():
"""Test walk of examples package for all underlying example modules."""
package = "geovista.examples"
result = get_modules(package)

module = importlib.import_module(package)
path = Path(module.__path__[0])
fnames = path.rglob("*.py")
expected = [
str(fname.relative_to(path)).replace(".py", "").replace(sep, ".")
for fname in fnames
if fname.name != "__init__.py"
]

assert result == sorted(expected)
33 changes: 16 additions & 17 deletions tests/plotting/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,21 @@
import importlib
import os
from pathlib import Path
import pkgutil
import shutil

import pytest
import pyvista as pv

import geovista as gv
from geovista.cache import CACHE
import geovista.examples
from geovista.common import get_modules

# determine whether executing on a GHA runner
# https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables
CI: bool = os.environ.get("CI", "false").lower() == "true"

# construct list of example script names
SCRIPTS = sorted(
[submodule.name for submodule in pkgutil.iter_modules(gv.examples.__path__)]
)
# construct list of example module names relative to "geovista.examples"
EXAMPLES = get_modules("geovista.examples")

# prepare geovista/pyvista for off-screen image testing
pv.global_theme.load_theme(pv.plotting.themes._TestingTheme())
Expand All @@ -49,7 +46,7 @@
# create the symbolic link to the pooch cache
cache_dir.symlink_to(base_dir)

# individual GHA CI test case exceptions to the default image tolerances
# individual GHA CI example test case exceptions to the default image tolerances
thresholds = {
"from_points__orca_cloud": {"warning_value": 202.0},
"from_points__orca_cloud_eqc": {"warning_value": 250.0},
Expand All @@ -58,21 +55,23 @@


@pytest.mark.image()
@pytest.mark.parametrize("script", SCRIPTS)
def test(script, verify_image_cache):
@pytest.mark.parametrize("example", EXAMPLES)
def test(example, verify_image_cache):
"""Image test the example scripts."""
# apply individual test case image tolerance exceptions only when
# executing within a remote GHA runner environment
if CI and script in thresholds:
for attr, value in thresholds[script].items():
if CI and example in thresholds:
for attr, value in thresholds[example].items():
setattr(verify_image_cache, attr, value)

verify_image_cache.test_name = f"test_{script}"
# import the example script
module = importlib.import_module(f"geovista.examples.{script}")
# if necessary, download and cache missing script base image (expected) to
# replace dot notation with double underscores
safe = example.replace(".", "__")
verify_image_cache.test_name = f"test_{safe}"
# import the example module
module = importlib.import_module(f"geovista.examples.{example}")
# if necessary, download and cache missing example base image (expected) to
# compare with the actual test image generated via pytest-pyvista plugin
if verify_image_cache.add_missing_images is False:
_ = CACHE.fetch(f"tests/images/{script}.png")
# execute the example script for image testing
_ = CACHE.fetch(f"tests/images/{safe}.png")
# execute the example module for image testing
module.main()

0 comments on commit a84b2b4

Please sign in to comment.