diff --git a/.flake8 b/.flake8 index 855c8d8a..03e9737e 100644 --- a/.flake8 +++ b/.flake8 @@ -35,3 +35,4 @@ exclude = build esmf_regrid/__init__.py esmf_regrid/tests/results + benchmarks/* diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 9f00653a..d71bce37 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -73,9 +73,7 @@ jobs: - name: Benchmark script run: | if ${{ github.event_name != 'pull_request' }}; then export COMPARE="HEAD~"; else export COMPARE="origin/${{ github.base_ref }}"; fi; - nox --session=tests --install-only - export DATA_GEN_PYTHON=$(realpath $(find .nox -path "*tests/bin/python")) - nox --session="benchmarks(branch)" -- "${COMPARE}" + python benchmarks/bm_runner.py branch ${COMPARE} - name: Archive ASV results uses: actions/upload-artifact@v4 diff --git a/.github/workflows/ci-manifest.yml b/.github/workflows/ci-manifest.yml index 5b1675b9..f92b2207 100644 --- a/.github/workflows/ci-manifest.yml +++ b/.github/workflows/ci-manifest.yml @@ -20,4 +20,4 @@ concurrency: jobs: manifest: name: "check-manifest" - uses: scitools/workflows/.github/workflows/ci-manifest.yml@2024.04.3 + uses: scitools/workflows/.github/workflows/ci-manifest.yml@2024.06.5 diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a91dca6..fd53b5c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,40 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/) and this project adheres to [Semantic Versioning](http://semver.org/). +## [0.10] - 2024-05-31 + +### Added + +- [PR#357](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/357) + Added support for saving and loading of `ESMFAreaWeighted`, `ESMFBilinear` + and `ESMFNearest` regridders. + [@stephenworsley](https://github.com/stephenworsley) +- [PR#319](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/319) + Added `CITATION.cff`. + [@bjlittle](https://github.com/bjlittle) + +### Changed + +- [PR#361](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/361) + Moved the code for running benchmarks to `bm_runner.py` in line with iris + benchmarks. + [@stephenworsley](https://github.com/stephenworsley) +- [PR#293](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/293) + Enumerated method and normtype input. + [@ESadek-MO](https://github.com/ESadek-MO) + +### Fixed + +- [PR#239](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/239) + Ensured dtype is preserved by regridding. + [@stephenworsley](https://github.com/stephenworsley) +- [PR#353](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/353) + Fixed a bug which caused errors with ESMF versions 8.6 and higher. + [@stephenworsley](https://github.com/stephenworsley) +- [PR#338](https://github.com/SciTools-incubator/iris-esmf-regrid/pull/338) + Fixed a potential memory leak when creating regridders. + [@stephenworsley](https://github.com/stephenworsley) + ## [0.9] - 2023-11-03 ### Added diff --git a/benchmarks/README.md b/benchmarks/README.md index aa7ec2ac..6c71eeec 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -16,16 +16,27 @@ raising a ❌ failure. installed, as well as Nox (see [Benchmark environments](#benchmark-environments)). -[iris-esmf-regrid's noxfile](../noxfile.py) includes a `benchmarks` session -that provides conveniences for setting up before benchmarking, and can also -replicate the CI run locally. See the session docstring for detail. +The benchmark runner ([bm_runner.py](./bm_runner.py)) provides conveniences for +common benchmark setup and run tasks, including replicating the automated +overnight run locally. See `python bm_runner.py --help` for detail. + +A significant portion of benchmark run time is environment management. Run-time +can be reduced by placing the benchmark environment on the same file system as +your +[Conda package cache](https://conda.io/projects/conda/en/latest/user-guide/configuration/use-condarc.html#specify-pkg-directories), +if it is not already. You can achieve this by either: + +- Temporarily reconfiguring `delegated_env_commands` and `delegated_env_parent` + in [asv.conf.json](asv.conf.json) to reference a location on the same file + system as the Conda package cache. +- Moving your Iris repo to the same file system as the Conda package cache. ### Environment variables * `DATA_GEN_PYTHON` - required - path to a Python executable that can be used to generate benchmark test objects/files; see -[Data generation](#data-generation). The Nox session sets this automatically, -but will defer to any value already set in the shell. +[Data generation](#data-generation). The benchmark runner sets this +automatically, but will defer to any value already set in the shell. * `BENCHMARK_DATA` - optional - path to a directory for benchmark synthetic test data, which the benchmark scripts will create if it doesn't already exist. Defaults to `/benchmarks/.data/` if not set. Note that some of @@ -34,7 +45,7 @@ plan accordingly. * `ON_DEMAND_BENCHMARKS` - optional - when set (to any value): benchmarks decorated with `@on_demand_benchmark` are included in the ASV run. Usually coupled with the ASV `--bench` argument to only run the benchmark(s) of -interest. Is set during the Nox `sperf` session. +interest. Is set during the benchmark runner `sperf` sub-commands. ### Reducing run time diff --git a/benchmarks/asv_delegated_conda.py b/benchmarks/asv_delegated_conda.py index 30810309..ab2478e3 100644 --- a/benchmarks/asv_delegated_conda.py +++ b/benchmarks/asv_delegated_conda.py @@ -189,6 +189,11 @@ def copy_asv_files(src_parent: Path, dst_parent: Path) -> None: # Record new environment information in properties. self._update_info() + def _run_conda(self, args, env=None): + # TODO: remove after airspeed-velocity/asv#1397 is merged and released. + args = ["--yes" if arg == "--force" else arg for arg in args] + return super()._run_conda(args, env) + def checkout_project(self, repo: Repo, commit_hash: str) -> None: """Check out the working tree of the project at given commit hash.""" super().checkout_project(repo, commit_hash) diff --git a/benchmarks/bm_runner.py b/benchmarks/bm_runner.py new file mode 100644 index 00000000..8588cd22 --- /dev/null +++ b/benchmarks/bm_runner.py @@ -0,0 +1,314 @@ +"""Argparse conveniences for executing common types of benchmark runs.""" + +from abc import ABC, abstractmethod +import argparse +from argparse import ArgumentParser +from datetime import datetime +from importlib import import_module +from os import environ +from pathlib import Path +import re +import shlex +import subprocess +from tempfile import NamedTemporaryFile +from typing import Literal + +# The threshold beyond which shifts are 'notable'. See `asv compare`` docs +# for more. +COMPARE_FACTOR = 1.2 + +BENCHMARKS_DIR = Path(__file__).parent +ROOT_DIR = BENCHMARKS_DIR.parent +# Storage location for reports used in GitHub actions. +GH_REPORT_DIR = ROOT_DIR.joinpath(".github", "workflows", "benchmark_reports") + +# Common ASV arguments for all run_types except `custom`. +ASV_HARNESS = "run {posargs} --attribute rounds=4 --interleave-rounds --show-stderr" + + +def _echo(echo_string: str): + # Use subprocess for printing to reduce chance of printing out of sequence + # with the subsequent calls. + subprocess.run(["echo", f"BM_RUNNER DEBUG: {echo_string}"]) + + +def _subprocess_runner(args, asv=False, **kwargs): + # Avoid permanent modifications if the same arguments are used more than once. + args = args.copy() + kwargs = kwargs.copy() + if asv: + args.insert(0, "asv") + kwargs["cwd"] = BENCHMARKS_DIR + _echo(" ".join(args)) + kwargs.setdefault("check", True) + return subprocess.run(args, **kwargs) + + +def _subprocess_runner_capture(args, **kwargs) -> str: + result = _subprocess_runner(args, capture_output=True, **kwargs) + return result.stdout.decode().rstrip() + + +def _check_requirements(package: str) -> None: + try: + import_module(package) + except ImportError as exc: + message = ( + f"No {package} install detected. Benchmarks can only " + f"be run in an environment including {package}." + ) + raise Exception(message) from exc + + +def _prep_data_gen_env() -> None: + """Create or access a separate, unchanging environment for generating test data.""" + python_version = "3.10" + data_gen_var = "DATA_GEN_PYTHON" + if data_gen_var in environ: + _echo("Using existing data generation environment.") + else: + _echo("Setting up the data generation environment ...") + # Get Nox to build an environment for the `tests` session, but don't + # run the session. Will reuse a cached environment if appropriate. + _subprocess_runner( + [ + "nox", + f"--noxfile={ROOT_DIR / 'noxfile.py'}", + "--session=tests", + "--install-only", + f"--python={python_version}", + ] + ) + # Find the environment built above, set it to be the data generation + # environment. + data_gen_python = next( + (ROOT_DIR / ".nox").rglob(f"tests*/bin/python{python_version}") + ).resolve() + environ[data_gen_var] = str(data_gen_python) + + _echo("Data generation environment ready.") + + +def _setup_common() -> None: + _check_requirements("asv") + _check_requirements("nox") + + _prep_data_gen_env() + + _echo("Setting up ASV ...") + _subprocess_runner(["machine", "--yes"], asv=True) + + _echo("Setup complete.") + + +def _asv_compare(*commits: str) -> None: + """Run through a list of commits comparing each one to the next.""" + commits = [commit[:8] for commit in commits] + for i in range(len(commits) - 1): + before = commits[i] + after = commits[i + 1] + asv_command = shlex.split( + f"compare {before} {after} --factor={COMPARE_FACTOR} --split" + ) + + comparison = _subprocess_runner_capture(asv_command, asv=True) + _echo(comparison) + shifts = _subprocess_runner_capture([*asv_command, "--only-changed"], asv=True) + _echo(shifts) + + +class _SubParserGenerator(ABC): + """Convenience for holding all the necessary argparse info in 1 place.""" + + name: str = NotImplemented + description: str = NotImplemented + epilog: str = NotImplemented + + def __init__(self, subparsers: ArgumentParser.add_subparsers) -> None: + self.subparser: ArgumentParser = subparsers.add_parser( + self.name, + description=self.description, + epilog=self.epilog, + formatter_class=argparse.RawTextHelpFormatter, + ) + self.add_arguments() + self.add_asv_arguments() + self.subparser.set_defaults(func=self.func) + + @abstractmethod + def add_arguments(self) -> None: + """All custom self.subparser.add_argument() calls.""" + _ = NotImplemented + + def add_asv_arguments(self) -> None: + self.subparser.add_argument( + "asv_args", + nargs=argparse.REMAINDER, + help="Any number of arguments to pass down to the ASV benchmark command.", + ) + + @staticmethod + @abstractmethod + def func(args: argparse.Namespace): + """Return when the subparser is parsed. + + `func` is then called, performing the user's selected sub-command. + + """ + _ = args + return NotImplemented + + +class Branch(_SubParserGenerator): + """Class for parsing and running the 'branch' argument.""" + + name = "branch" + description = ( + "Benchmarks the two commits,``HEAD``, and ``HEAD``'s merge-base with the " + "input **base_branch**. If running on GitHub Actions: HEAD will be " + "GitHub's merge commit and merge-base will be the merge target. Performance " + "comparisons will be posted in the CI run which will fail if regressions " + "exceed the tolerance.\n" + "Uses `asv run`." + ) + epilog = ( + "e.g. python bm_runner.py branch upstream/main\n" + "e.g. python bm_runner.py branch upstream/main --bench=regridding" + ) + + def add_arguments(self) -> None: + self.subparser.add_argument( + "base_branch", + type=str, + help="A branch that has the merge-base with ``HEAD`` - ``HEAD`` will be benchmarked against that merge-base.", + ) + + @staticmethod + def func(args: argparse.Namespace) -> None: + _setup_common() + + git_command = shlex.split("git rev-parse HEAD") + head_sha = _subprocess_runner_capture(git_command)[:8] + + git_command = shlex.split(f"git merge-base {head_sha} {args.base_branch}") + merge_base = _subprocess_runner_capture(git_command)[:8] + + with NamedTemporaryFile("w") as hashfile: + hashfile.writelines([merge_base, "\n", head_sha]) + hashfile.flush() + commit_range = f"HASHFILE:{hashfile.name}" + asv_command = shlex.split(ASV_HARNESS.format(posargs=commit_range)) + _subprocess_runner([*asv_command, *args.asv_args], asv=True) + + _asv_compare(merge_base, head_sha) + + +class SPerf(_SubParserGenerator): + """Class for parsing and running the 'sperf' argument.""" + + name = "sperf" + description = ( + "Run the on-demand Sperf suite of benchmarks (measuring " + "scalability) for the ``HEAD`` of ``upstream/main`` only, " + "and publish the results to the input **publish_dir**, within a " + "unique subdirectory for this run.\n" + "Uses `asv run`." + ) + epilog = ( + "e.g. python bm_runner.py sperf my_publish_dir\n" + "e.g. python bm_runner.py sperf my_publish_dir --bench=regridding" + ) + + def add_arguments(self) -> None: + self.subparser.add_argument( + "publish_dir", + type=str, + help="HTML results will be published to a sub-dir in this dir.", + ) + + @staticmethod + def func(args: argparse.Namespace) -> None: + _setup_common() + + publish_dir = Path(args.publish_dir) + if not publish_dir.is_dir(): + message = f"Input 'publish directory' is not a directory: {publish_dir}" + raise NotADirectoryError(message) + publish_subdir = ( + publish_dir / f"sperf_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + publish_subdir.mkdir() + + # Activate on demand benchmarks (C/SPerf are deactivated for + # 'standard' runs). + environ["ON_DEMAND_BENCHMARKS"] = "True" + commit_range = "upstream/main^!" + + asv_command = ( + ASV_HARNESS.format(posargs=commit_range) + " --bench=.*Scalability.*" + ) + + # Only do a single round. + asv_command = shlex.split(re.sub(r"rounds=\d", "rounds=1", asv_command)) + try: + _subprocess_runner([*asv_command, *args.asv_args], asv=True) + except subprocess.CalledProcessError as err: + # C/SPerf benchmarks are much bigger than the CI ones: + # Don't fail the whole run if memory blows on 1 benchmark. + # ASV produces return code of 2 if the run includes crashes. + if err.returncode != 2: + raise + + asv_command = shlex.split(f"publish {commit_range} --html-dir={publish_subdir}") + _subprocess_runner(asv_command, asv=True) + + # Print completion message. + location = BENCHMARKS_DIR / ".asv" + _echo( + f'New ASV results for "sperf".\n' + f'See "{publish_subdir}",' + f'\n or JSON files under "{location / "results"}".' + ) + + +class Custom(_SubParserGenerator): + """Class for parsing and running the 'custom' argument.""" + + name = "custom" + description = ( + "Run ASV with the input **ASV sub-command**, without any preset " + "arguments - must all be supplied by the user. So just like running " + "ASV manually, with the convenience of re-using the runner's " + "scripted setup steps." + ) + epilog = "e.g. python bm_runner.py custom continuous a1b23d4 HEAD --quick" + + def add_arguments(self) -> None: + self.subparser.add_argument( + "asv_sub_command", + type=str, + help="The ASV command to run.", + ) + + @staticmethod + def func(args: argparse.Namespace) -> None: + _setup_common() + _subprocess_runner([args.asv_sub_command, *args.asv_args], asv=True) + + +def main(): + parser = ArgumentParser( + description="Run the Iris performance benchmarks (using Airspeed Velocity).", + epilog="More help is available within each sub-command.", + ) + subparsers = parser.add_subparsers(required=True) + + for gen in (Branch, SPerf, Custom): + _ = gen(subparsers).subparser + + parsed = parser.parse_args() + parsed.func(parsed) + + +if __name__ == "__main__": + main() diff --git a/docs/src/userguide/examples.rst b/docs/src/userguide/examples.rst index b66e33b7..06564656 100644 --- a/docs/src/userguide/examples.rst +++ b/docs/src/userguide/examples.rst @@ -31,10 +31,10 @@ Saving and Loading a Regridder A regridder can be set up for reuse, this saves time performing the computationally expensive initialisation process:: - from esmf_regrid.experimental.unstructured_scheme import MeshToGridESMFRegridder + from esmf_regrid.experimental.unstructured_scheme import ESMFAreaWeighted # Initialise the regridder with a source mesh and target grid. - regridder = MeshToGridESMFRegridder(source_mesh_cube, target_grid_cube) + regridder = ESMFAreaWeighted().regridder(source_mesh_cube, target_grid_cube) # use the initialised regridder to regrid the data from the source cube # onto a cube with the same grid as `target_grid_cube`. diff --git a/docs/src/userguide/scheme_comparison.rst b/docs/src/userguide/scheme_comparison.rst index e29a78d6..a9bb4a1f 100644 --- a/docs/src/userguide/scheme_comparison.rst +++ b/docs/src/userguide/scheme_comparison.rst @@ -61,10 +61,12 @@ These were formerly the only way to do regridding with a source or target cube defined on an unstructured mesh. These are less flexible and require that the source/target be defined on a grid/mesh. Unlike the above regridders whose method is fixed, these regridders take a ``method`` keyword -of ``conservative``, ``bilinear`` or ``nearest``. While most of the -functionality in these regridders have been ported into the above schemes and -regridders, these remain the only regridders capable of being saved and loaded by -:mod:`esmf_regrid.experimental.io`. +of ``conservative``, ``bilinear`` or ``nearest``. All the +functionality in these regridders has now been ported into the above schemes and +regridders. Before version 0.10, these were the only regridders capable of being +saved and loaded by :mod:`esmf_regrid.experimental.io`, so while the above generic +regridders are recomended, these regridders are still available for the sake of +consistency with regridders saved from older versions. Overview: Miscellaneous Functions diff --git a/esmf_regrid/__init__.py b/esmf_regrid/__init__.py index 850de104..8ba175fc 100644 --- a/esmf_regrid/__init__.py +++ b/esmf_regrid/__init__.py @@ -12,4 +12,4 @@ from .schemes import * -__version__ = "0.10.dev0" +__version__ = "0.11.dev0" diff --git a/esmf_regrid/esmf_regridder.py b/esmf_regrid/esmf_regridder.py index 07d81071..5322c415 100644 --- a/esmf_regrid/esmf_regridder.py +++ b/esmf_regrid/esmf_regridder.py @@ -134,6 +134,12 @@ def __init__( self.esmf_version = None self.weight_matrix = precomputed_weights + def _out_dtype(self, in_dtype): + """Return the expected output dtype for a given input dtype.""" + weight_dtype = self.weight_matrix.dtype + out_dtype = (np.ones(1, dtype=in_dtype) * np.ones(1, dtype=weight_dtype)).dtype + return out_dtype + def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): """ Perform regridding on an array of data. @@ -175,12 +181,13 @@ def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): extra_size = max(1, np.prod(extra_shape)) src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array)) weight_sums = self.weight_matrix @ src_inverted_mask + out_dtype = self._out_dtype(src_array.dtype) # Set the minimum mdtol to be slightly higher than 0 to account for rounding # errors. mdtol = max(mdtol, 1e-8) tgt_mask = weight_sums > 1 - mdtol masked_weight_sums = weight_sums * tgt_mask - normalisations = np.ones([self.tgt.size, extra_size]) + normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype) if norm_type == Constants.NormType.FRACAREA: normalisations[tgt_mask] /= masked_weight_sums[tgt_mask] elif norm_type == Constants.NormType.DSTAREA: diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index 624820c3..490d0485 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -1,5 +1,7 @@ """Provides load/save functions for regridders.""" +from contextlib import contextmanager + import iris from iris.coords import AuxCoord from iris.cube import Cube, CubeList @@ -13,9 +15,19 @@ GridToMeshESMFRegridder, MeshToGridESMFRegridder, ) +from esmf_regrid.schemes import ( + ESMFAreaWeightedRegridder, + ESMFBilinearRegridder, + ESMFNearestRegridder, + GridRecord, + MeshRecord, +) SUPPORTED_REGRIDDERS = [ + ESMFAreaWeightedRegridder, + ESMFBilinearRegridder, + ESMFNearestRegridder, GridToMeshESMFRegridder, MeshToGridESMFRegridder, ] @@ -34,6 +46,8 @@ MDTOL = "mdtol" METHOD = "method" RESOLUTION = "resolution" +SOURCE_RESOLUTION = "src_resolution" +TARGET_RESOLUTION = "tgt_resolution" def _add_mask_to_cube(mask, cube, name): @@ -43,18 +57,63 @@ def _add_mask_to_cube(mask, cube, name): cube.add_aux_coord(mask_coord, list(range(cube.ndim))) +@contextmanager +def _managed_var_name(src_cube, tgt_cube): + src_coord_names = [] + src_mesh_coords = [] + if src_cube.mesh is not None: + src_mesh = src_cube.mesh + src_mesh_coords = src_mesh.coords() + for coord in src_mesh_coords: + src_coord_names.append(coord.var_name) + tgt_coord_names = [] + tgt_mesh_coords = [] + if tgt_cube.mesh is not None: + tgt_mesh = tgt_cube.mesh + tgt_mesh_coords = tgt_mesh.coords() + for coord in tgt_mesh_coords: + tgt_coord_names.append(coord.var_name) + + try: + for coord in src_mesh_coords: + coord.var_name = "_".join([SOURCE_NAME, "mesh", coord.name()]) + for coord in tgt_mesh_coords: + coord.var_name = "_".join([TARGET_NAME, "mesh", coord.name()]) + yield None + finally: + for coord, var_name in zip(src_mesh_coords, src_coord_names): + coord.var_name = var_name + for coord, var_name in zip(tgt_mesh_coords, tgt_coord_names): + coord.var_name = var_name + + +def _clean_var_names(cube): + cube.var_name = None + for coord in cube.coords(): + coord.var_name = None + if cube.mesh is not None: + cube.mesh.var_name = None + for coord in cube.mesh.coords(): + coord.var_name = None + for con in cube.mesh.connectivities(): + con.var_name = None + + def save_regridder(rg, filename): """ Save a regridder scheme instance. - Saves either a - :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder` - or a - :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`. + Saves any of the regridder classes, i.e. + :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`, + :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`, + :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`, + :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or + :class:`~esmf_regrid.schemes.ESMFNearestRegridder`. + . Parameters ---------- - rg : :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder` or :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder` + rg : :class:`~esmf_regrid.schemes._ESMFRegridder` The regridder instance to save. filename : str The file name to save to. @@ -76,28 +135,56 @@ def _standard_grid_cube(grid, name): cube.add_aux_coord(grid[1], [0, 1]) return cube - if regridder_type == "GridToMeshESMFRegridder": + def _standard_mesh_cube(mesh, location, name): + mesh_coords = mesh.to_MeshCoords(location) + data = np.zeros(mesh_coords[0].points.shape[0]) + cube = Cube(data, var_name=name, long_name=name) + for coord in mesh_coords: + cube.add_aux_coord(coord, 0) + return cube + + if regridder_type in [ + "ESMFAreaWeightedRegridder", + "ESMFBilinearRegridder", + "ESMFNearestRegridder", + ]: + src_grid = rg._src + if isinstance(src_grid, GridRecord): + src_cube = _standard_grid_cube( + (src_grid.grid_y, src_grid.grid_x), SOURCE_NAME + ) + elif isinstance(src_grid, MeshRecord): + src_mesh, src_location = src_grid + src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME) + else: + raise ValueError("Improper type for `rg._src`.") + _add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME) + + tgt_grid = rg._tgt + if isinstance(tgt_grid, GridRecord): + tgt_cube = _standard_grid_cube( + (tgt_grid.grid_y, tgt_grid.grid_x), TARGET_NAME + ) + elif isinstance(tgt_grid, MeshRecord): + tgt_mesh, tgt_location = tgt_grid + tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME) + else: + raise ValueError("Improper type for `rg._tgt`.") + _add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME) + elif regridder_type == "GridToMeshESMFRegridder": src_grid = (rg.grid_y, rg.grid_x) src_cube = _standard_grid_cube(src_grid, SOURCE_NAME) _add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME) tgt_mesh = rg.mesh tgt_location = rg.location - tgt_mesh_coords = tgt_mesh.to_MeshCoords(tgt_location) - tgt_data = np.zeros(tgt_mesh_coords[0].points.shape[0]) - tgt_cube = Cube(tgt_data, var_name=TARGET_NAME, long_name=TARGET_NAME) - for coord in tgt_mesh_coords: - tgt_cube.add_aux_coord(coord, 0) + tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME) _add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME) elif regridder_type == "MeshToGridESMFRegridder": src_mesh = rg.mesh src_location = rg.location - src_mesh_coords = src_mesh.to_MeshCoords(src_location) - src_data = np.zeros(src_mesh_coords[0].points.shape[0]) - src_cube = Cube(src_data, var_name=SOURCE_NAME, long_name=SOURCE_NAME) - for coord in src_mesh_coords: - src_cube.add_aux_coord(coord, 0) + src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME) _add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME) tgt_grid = (rg.grid_y, rg.grid_x) @@ -112,7 +199,18 @@ def _standard_grid_cube(grid, name): method = str(check_method(rg.method).name) - resolution = rg.resolution + if regridder_type in ["GridToMeshESMFRegridder", "MeshToGridESMFRegridder"]: + resolution = rg.resolution + src_resolution = None + tgt_resolution = None + elif regridder_type == "ESMFAreaWeightedRegridder": + resolution = None + src_resolution = rg.src_resolution + tgt_resolution = rg.tgt_resolution + else: + resolution = None + src_resolution = None + tgt_resolution = None weight_matrix = rg.regridder.weight_matrix reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix) @@ -141,6 +239,10 @@ def _standard_grid_cube(grid, name): } if resolution is not None: attributes[RESOLUTION] = resolution + if src_resolution is not None: + attributes[SOURCE_RESOLUTION] = src_resolution + if tgt_resolution is not None: + attributes[TARGET_RESOLUTION] = tgt_resolution weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME) row_coord = AuxCoord( @@ -158,17 +260,14 @@ def _standard_grid_cube(grid, name): long_name=WEIGHTS_SHAPE_NAME, ) - # Avoid saving bug by placing the mesh cube second. - # TODO: simplify this when this bug is fixed in iris. - if regridder_type == "GridToMeshESMFRegridder": + # Save cubes while ensuring var_names do not conflict for the sake of consistency. + with _managed_var_name(src_cube, tgt_cube): cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube]) - elif regridder_type == "MeshToGridESMFRegridder": - cube_list = CubeList([tgt_cube, src_cube, weights_cube, weight_shape_cube]) - for cube in cube_list: - cube.attributes = attributes + for cube in cube_list: + cube.attributes = attributes - iris.fileformats.netcdf.save(cube_list, filename) + iris.fileformats.netcdf.save(cube_list, filename) def load_regridder(filename): @@ -194,7 +293,9 @@ def load_regridder(filename): # Extract the source, target and metadata information. src_cube = cubes.extract_cube(SOURCE_NAME) + _clean_var_names(src_cube) tgt_cube = cubes.extract_cube(TARGET_NAME) + _clean_var_names(tgt_cube) weights_cube = cubes.extract_cube(WEIGHTS_NAME) weight_shape_cube = cubes.extract_cube(WEIGHTS_SHAPE_NAME) @@ -210,8 +311,14 @@ def load_regridder(filename): ) resolution = weights_cube.attributes.get(RESOLUTION, None) + src_resolution = weights_cube.attributes.get(SOURCE_RESOLUTION, None) + tgt_resolution = weights_cube.attributes.get(TARGET_RESOLUTION, None) if resolution is not None: resolution = int(resolution) + if src_resolution is not None: + src_resolution = int(src_resolution) + if tgt_resolution is not None: + tgt_resolution = int(tgt_resolution) # Reconstruct the weight matrix. weight_data = weights_cube.data @@ -234,18 +341,25 @@ def load_regridder(filename): use_tgt_mask = False if scheme is GridToMeshESMFRegridder: - resolution_keyword = "src_resolution" + resolution_keyword = SOURCE_RESOLUTION + kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol} elif scheme is MeshToGridESMFRegridder: - resolution_keyword = "tgt_resolution" + resolution_keyword = TARGET_RESOLUTION + kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol} + elif scheme is ESMFAreaWeightedRegridder: + kwargs = { + SOURCE_RESOLUTION: src_resolution, + TARGET_RESOLUTION: tgt_resolution, + "mdtol": mdtol, + } + elif scheme is ESMFBilinearRegridder: + kwargs = {"mdtol": mdtol} else: - raise NotImplementedError - kwargs = {resolution_keyword: resolution} + kwargs = {} regridder = scheme( src_cube, tgt_cube, - mdtol=mdtol, - method=method, precomputed_weights=weight_matrix, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, diff --git a/esmf_regrid/schemes.py b/esmf_regrid/schemes.py index 5c759e38..ebfd274c 100644 --- a/esmf_regrid/schemes.py +++ b/esmf_regrid/schemes.py @@ -284,7 +284,9 @@ def _regrid_along_dims(data, regridder, dims, num_out_dims, mdtol): return result -def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs): +def _map_complete_blocks( + src, func, active_dims, out_sizes, *args, dtype=None, **kwargs +): """ Apply a function to complete blocks. @@ -308,6 +310,8 @@ def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs): Dimensions that cannot be chunked. out_sizes : tuple of int Output size of dimensions that cannot be chunked. + dtype : type, optional + Type of the output array, if not given, the dtype of src is used. Returns ------- @@ -320,6 +324,8 @@ def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs): return func(src.data, *args, **kwargs) data = src.lazy_data() + if dtype is None: + dtype = data.dtype # Ensure dims are not chunked in_chunks = list(data.chunks) @@ -382,7 +388,7 @@ def _map_complete_blocks(src, func, active_dims, out_sizes, *args, **kwargs): chunks=out_chunks, drop_axis=dropped_dims, new_axis=new_axis, - dtype=src.dtype, + dtype=dtype, **kwargs, ) @@ -566,6 +572,8 @@ def _regrid_rectilinear_to_rectilinear__perform(src_cube, regrid_info, mdtol): grid_x, grid_y = regrid_info.target regridder = regrid_info.regridder + out_dtype = regridder._out_dtype(src_cube.dtype) + # Apply regrid to all the chunks of src_cube, ensuring first that all # chunks cover the entire horizontal plane (otherwise they would break # the regrid function). @@ -583,6 +591,7 @@ def _regrid_rectilinear_to_rectilinear__perform(src_cube, regrid_info, mdtol): dims=[grid_x_dim, grid_y_dim], num_out_dims=2, mdtol=mdtol, + dtype=out_dtype, ) new_cube = _create_cube( @@ -645,6 +654,8 @@ def _regrid_unstructured_to_rectilinear__perform(src_cube, regrid_info, mdtol): grid_x, grid_y = regrid_info.target regridder = regrid_info.regridder + out_dtype = regridder._out_dtype(src_cube.dtype) + # Apply regrid to all the chunks of src_cube, ensuring first that all # chunks cover the entire horizontal plane (otherwise they would break # the regrid function). @@ -662,6 +673,7 @@ def _regrid_unstructured_to_rectilinear__perform(src_cube, regrid_info, mdtol): dims=[mesh_dim], num_out_dims=2, mdtol=mdtol, + dtype=out_dtype, ) new_cube = _create_cube( @@ -748,6 +760,8 @@ def _regrid_rectilinear_to_unstructured__perform(src_cube, regrid_info, mdtol): else: raise NotImplementedError(f"Unrecognised location {location}.") + out_dtype = regridder._out_dtype(src_cube.dtype) + # Apply regrid to all the chunks of src_cube, ensuring first that all # chunks cover the entire horizontal plane (otherwise they would break # the regrid function). @@ -760,6 +774,7 @@ def _regrid_rectilinear_to_unstructured__perform(src_cube, regrid_info, mdtol): dims=[grid_x_dim, grid_y_dim], num_out_dims=1, mdtol=mdtol, + dtype=out_dtype, ) new_cube = _create_cube( @@ -832,6 +847,8 @@ def _regrid_unstructured_to_unstructured__perform(src_cube, regrid_info, mdtol): mesh, location = regrid_info.target regridder = regrid_info.regridder + out_dtype = regridder._out_dtype(src_cube.dtype) + if location == "face": face_node = mesh.face_node_connectivity chunk_shape = (face_node.shape[face_node.location_axis],) @@ -849,6 +866,7 @@ def _regrid_unstructured_to_unstructured__perform(src_cube, regrid_info, mdtol): dims=[mesh_dim], num_out_dims=1, mdtol=mdtol, + dtype=out_dtype, ) new_cube = _create_cube( @@ -975,6 +993,8 @@ def regridder( self, src_grid, tgt_grid, + src_resolution=None, + tgt_resolution=None, use_src_mask=None, use_tgt_mask=None, tgt_location="face", @@ -993,6 +1013,11 @@ def regridder( :class:`~iris.experimental.ugrid.Mesh` defining the target. If this cube has a grid defined by latitude/longitude coordinates, those coordinates must have bounds. + src_resolution, tgt_resolution : int, optional + If present, represents the amount of latitude slices per source/target cell + given to ESMF for calculation. If resolution is set, ``src`` and ``tgt`` + respectively must have strictly increasing bounds (bounds may be transposed + plus or minus 360 degrees to make the bounds strictly increasing). use_src_mask : :obj:`~numpy.typing.ArrayLike` or bool, optional Array describing which elements :mod:`esmpy` will ignore on the src_grid. If True, the mask will be derived from src_grid. @@ -1030,6 +1055,8 @@ def regridder( src_grid, tgt_grid, mdtol=self.mdtol, + src_resolution=src_resolution, + tgt_resolution=tgt_resolution, use_src_mask=use_src_mask, use_tgt_mask=use_tgt_mask, tgt_location="face", @@ -1482,8 +1509,10 @@ def __init__( if tgt_location is not "face". """ kwargs = dict() + self.src_resolution = src_resolution if src_resolution is not None: kwargs["src_resolution"] = src_resolution + self.tgt_resolution = tgt_resolution if tgt_resolution is not None: kwargs["tgt_resolution"] = tgt_resolution if tgt_location is not None and tgt_location != "face": diff --git a/esmf_regrid/tests/conftest.py b/esmf_regrid/tests/conftest.py new file mode 100644 index 00000000..eec32b1d --- /dev/null +++ b/esmf_regrid/tests/conftest.py @@ -0,0 +1,22 @@ +"""Common testing infrastructure.""" + +import pytest + + +@pytest.fixture(params=["float32", "float64"]) +def in_dtype(request): + """Fixture for controlling dtype.""" + return request.param + + +@pytest.fixture( + params=[ + ("grid", "grid"), + ("grid", "mesh"), + ("mesh", "grid"), + ("mesh", "mesh"), + ] +) +def src_tgt_types(request): + """Fixture for controlling type of source and target.""" + return request.param diff --git a/esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py b/esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py index c29d6df8..81be1302 100644 --- a/esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py +++ b/esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py @@ -226,3 +226,29 @@ def _get_points(bounds): (weights_dict["weights"], (weights_dict["row_dst"], weights_dict["col_src"])) ) assert np.allclose(result.toarray(), expected_weights.toarray()) + + +def test_Regridder_dtype_handling(): + """ + Basic test for :meth:`~esmf_regrid.esmf_regridder.Regridder.regrid`. + + Tests that dtype is handled as expected. + """ + lon, lat, lon_bounds, lat_bounds = make_grid_args(2, 3) + src_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) + + lon, lat, lon_bounds, lat_bounds = make_grid_args(3, 2) + tgt_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) + + # Set up the regridder with precomputed weights. + rg_64 = Regridder(src_grid, tgt_grid, precomputed_weights=_expected_weights()) + weights_32 = _expected_weights().astype(np.float32) + rg_32 = Regridder(src_grid, tgt_grid, precomputed_weights=weights_32) + + src_32 = np.ones([3, 2], dtype=np.float32) + src_64 = np.ones([3, 2], dtype=np.float64) + + assert rg_64.regrid(src_64).dtype == np.float64 + assert rg_64.regrid(src_32).dtype == np.float64 + assert rg_32.regrid(src_64).dtype == np.float64 + assert rg_32.regrid(src_32).dtype == np.float32 diff --git a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index c16adcef..f17fd941 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -4,7 +4,13 @@ from numpy import ma import pytest -from esmf_regrid import Constants +from esmf_regrid import ( + Constants, + ESMFAreaWeighted, + ESMFAreaWeightedRegridder, + ESMFBilinear, + ESMFNearest, +) from esmf_regrid.experimental.io import load_regridder, save_regridder from esmf_regrid.experimental.unstructured_scheme import ( GridToMeshESMFRegridder, @@ -21,6 +27,7 @@ def _make_grid_to_mesh_regridder( method=Constants.Method.CONSERVATIVE, + regridder=GridToMeshESMFRegridder, resolution=None, grid_dims=1, circular=True, @@ -60,20 +67,21 @@ def _make_grid_to_mesh_regridder( use_src_mask = False use_tgt_mask = False - rg = GridToMeshESMFRegridder( - src, - tgt, - method=method, - mdtol=0.5, - src_resolution=resolution, - use_src_mask=use_src_mask, - use_tgt_mask=use_tgt_mask, - ) + kwargs = { + "mdtol": 0.5, + "src_resolution": resolution, + "use_src_mask": use_src_mask, + "use_tgt_mask": use_tgt_mask, + } + if regridder == GridToMeshESMFRegridder: + kwargs["method"] = method + rg = regridder(src, tgt, **kwargs) return rg, src def _make_mesh_to_grid_regridder( method=Constants.Method.CONSERVATIVE, + regridder=MeshToGridESMFRegridder, resolution=None, grid_dims=1, circular=True, @@ -83,7 +91,10 @@ def _make_mesh_to_grid_regridder( src_lats = 4 tgt_lons = 5 tgt_lats = 6 - lon_bounds = (-180, 180) + if circular: + lon_bounds = (-180, 180) + else: + lon_bounds = (-180, 170) lat_bounds = (-90, 90) if grid_dims == 1: tgt = _grid_cube(tgt_lons, tgt_lats, lon_bounds, lat_bounds, circular=circular) @@ -110,38 +121,57 @@ def _make_mesh_to_grid_regridder( use_src_mask = False use_tgt_mask = False - rg = MeshToGridESMFRegridder( + kwargs = { + "mdtol": 0.5, + "tgt_resolution": resolution, + "use_src_mask": use_src_mask, + "use_tgt_mask": use_tgt_mask, + } + if regridder == MeshToGridESMFRegridder: + kwargs["method"] = method + rg = regridder( src, tgt, - method=method, - mdtol=0.5, - tgt_resolution=resolution, - use_src_mask=use_src_mask, - use_tgt_mask=use_tgt_mask, + **kwargs, ) return rg, src +def _compare_ignoring_var_names(x, y): + old_var_name = x.var_name + x.var_name = y.var_name + assert x == y + x.var_name = old_var_name + + @pytest.mark.parametrize( - "method", + "method,regridder", [ - Constants.Method.CONSERVATIVE, - Constants.Method.BILINEAR, - Constants.Method.NEAREST, + (Constants.Method.CONSERVATIVE, GridToMeshESMFRegridder), + (Constants.Method.BILINEAR, GridToMeshESMFRegridder), + (Constants.Method.NEAREST, GridToMeshESMFRegridder), + (None, ESMFAreaWeightedRegridder), ], ) -def test_GridToMeshESMFRegridder_round_trip(tmp_path, method): - """Test save/load round tripping for `GridToMeshESMFRegridder`.""" - original_rg, src = _make_grid_to_mesh_regridder(method=method, circular=True) +def test_grid_to_mesh_round_trip(tmp_path, method, regridder): + """Test save/load round tripping for grid to mesh regridding.""" + original_rg, src = _make_grid_to_mesh_regridder( + method=method, regridder=regridder, circular=True + ) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - assert original_rg.location == loaded_rg.location + if regridder == GridToMeshESMFRegridder: + assert original_rg.location == loaded_rg.location + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + else: + assert original_rg._tgt.location == loaded_rg._tgt.location + _compare_ignoring_var_names(original_rg._src[0], loaded_rg._src[0]) + _compare_ignoring_var_names(original_rg._src[1], loaded_rg._src[1]) assert original_rg.method == loaded_rg.method assert original_rg.mdtol == loaded_rg.mdtol - assert original_rg.grid_x == loaded_rg.grid_x - assert original_rg.grid_y == loaded_rg.grid_y # TODO: uncomment when iris mesh comparison becomes available. # assert original_rg.mesh == loaded_rg.mesh @@ -181,25 +211,52 @@ def test_GridToMeshESMFRegridder_round_trip(tmp_path, method): original_res_rg.regridder.src.resolution == loaded_res_rg.regridder.src.resolution ) + elif regridder == ESMFAreaWeightedRegridder: + assert original_rg.src_resolution == loaded_rg.src_resolution + original_res_rg, _ = _make_grid_to_mesh_regridder( + regridder=regridder, resolution=8 + ) + res_filename = tmp_path / "regridder_res.nc" + save_regridder(original_res_rg, res_filename) + loaded_res_rg = load_regridder(str(res_filename)) + assert original_res_rg.src_resolution == loaded_res_rg.src_resolution + assert ( + original_res_rg.regridder.src.resolution + == loaded_res_rg.regridder.src.resolution + ) # Ensure grid equality for non-circular coords. - original_nc_rg, _ = _make_grid_to_mesh_regridder(method=method, circular=False) + original_nc_rg, src = _make_grid_to_mesh_regridder( + method=method, regridder=regridder, circular=True + ) nc_filename = tmp_path / "non_circular_regridder.nc" save_regridder(original_nc_rg, nc_filename) loaded_nc_rg = load_regridder(str(nc_filename)) - assert original_nc_rg.grid_x == loaded_nc_rg.grid_x - assert original_nc_rg.grid_y == loaded_nc_rg.grid_y + if regridder == GridToMeshESMFRegridder: + _compare_ignoring_var_names(original_nc_rg.grid_x, loaded_nc_rg.grid_x) + _compare_ignoring_var_names(original_nc_rg.grid_y, loaded_nc_rg.grid_y) + else: + _compare_ignoring_var_names(original_nc_rg._src[0], loaded_nc_rg._src[0]) + _compare_ignoring_var_names(original_nc_rg._src[1], loaded_nc_rg._src[1]) -def test_GridToMeshESMFRegridder_curvilinear_round_trip(tmp_path): - """Test save/load round tripping for `GridToMeshESMFRegridder`.""" - original_rg, src = _make_grid_to_mesh_regridder(grid_dims=2) +@pytest.mark.parametrize( + "regridder", + [GridToMeshESMFRegridder, ESMFAreaWeightedRegridder], +) +def test_grid_to_mesh_curvilinear_round_trip(tmp_path, regridder): + """Test save/load round tripping for grid to mesh regridding.""" + original_rg, src = _make_grid_to_mesh_regridder(regridder=regridder, grid_dims=2) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - assert original_rg.grid_x == loaded_rg.grid_x - assert original_rg.grid_y == loaded_rg.grid_y + if regridder == GridToMeshESMFRegridder: + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + else: + _compare_ignoring_var_names(original_rg._src[0], loaded_rg._src[0]) + _compare_ignoring_var_names(original_rg._src[1], loaded_rg._src[1]) # Demonstrate regridding still gives the same results. src_data = ma.arange(np.product(src.data.shape)).reshape(src.data.shape) @@ -213,14 +270,21 @@ def test_GridToMeshESMFRegridder_curvilinear_round_trip(tmp_path): # TODO: parametrize the rest of the tests in this module. +@pytest.mark.parametrize( + "regridder", + ["unstructured", ESMFAreaWeightedRegridder], +) @pytest.mark.parametrize( "rg_maker", [_make_grid_to_mesh_regridder, _make_mesh_to_grid_regridder], ids=["grid_to_mesh", "mesh_to_grid"], ) -def test_MeshESMFRegridder_masked_round_trip(tmp_path, rg_maker): +def test_MeshESMFRegridder_masked_round_trip(tmp_path, rg_maker, regridder): """Test save/load round tripping for the Mesh regridder classes.""" - original_rg, src = rg_maker(masks=True) + if regridder == "unstructured": + original_rg, src = rg_maker(masks=True) + else: + original_rg, src = rg_maker(regridder=regridder, masks=True) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) @@ -238,25 +302,34 @@ def test_MeshESMFRegridder_masked_round_trip(tmp_path, rg_maker): @pytest.mark.parametrize( - "method", + "method,regridder", [ - Constants.Method.CONSERVATIVE, - Constants.Method.BILINEAR, - Constants.Method.NEAREST, + (Constants.Method.CONSERVATIVE, MeshToGridESMFRegridder), + (Constants.Method.BILINEAR, MeshToGridESMFRegridder), + (Constants.Method.NEAREST, MeshToGridESMFRegridder), + (None, ESMFAreaWeightedRegridder), ], ) -def test_MeshToGridESMFRegridder_round_trip(tmp_path, method): - """Test save/load round tripping for `MeshToGridESMFRegridder`.""" - original_rg, src = _make_mesh_to_grid_regridder(method=method, circular=True) +def test_mesh_to_grid_round_trip(tmp_path, method, regridder): + """Test save/load round tripping for mesh to grid regridding.""" + original_rg, src = _make_mesh_to_grid_regridder( + method=method, regridder=regridder, circular=True + ) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - assert original_rg.location == loaded_rg.location + if regridder == MeshToGridESMFRegridder: + assert original_rg.location == loaded_rg.location + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + else: + assert original_rg._src.location == loaded_rg._src.location + _compare_ignoring_var_names(original_rg._tgt[0], loaded_rg._tgt[0]) + _compare_ignoring_var_names(original_rg._tgt[1], loaded_rg._tgt[1]) + assert original_rg.method == loaded_rg.method assert original_rg.mdtol == loaded_rg.mdtol - assert original_rg.grid_x == loaded_rg.grid_x - assert original_rg.grid_y == loaded_rg.grid_y # TODO: uncomment when iris mesh comparison becomes available. # assert original_rg.mesh == loaded_rg.mesh @@ -295,25 +368,52 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path, method): original_res_rg.regridder.tgt.resolution == loaded_res_rg.regridder.tgt.resolution ) + elif regridder == ESMFAreaWeightedRegridder: + assert original_rg.src_resolution == loaded_rg.src_resolution + original_res_rg, _ = _make_mesh_to_grid_regridder( + regridder=regridder, resolution=8 + ) + res_filename = tmp_path / "regridder_res.nc" + save_regridder(original_res_rg, res_filename) + loaded_res_rg = load_regridder(str(res_filename)) + assert original_res_rg.tgt_resolution == loaded_res_rg.tgt_resolution + assert ( + original_res_rg.regridder.tgt.resolution + == loaded_res_rg.regridder.tgt.resolution + ) # Ensure grid equality for non-circular coords. - original_nc_rg, _ = _make_grid_to_mesh_regridder(method=method, circular=False) + original_nc_rg, _ = _make_mesh_to_grid_regridder( + method=method, regridder=regridder, circular=False + ) nc_filename = tmp_path / "non_circular_regridder.nc" save_regridder(original_nc_rg, nc_filename) loaded_nc_rg = load_regridder(str(nc_filename)) - assert original_nc_rg.grid_x == loaded_nc_rg.grid_x - assert original_nc_rg.grid_y == loaded_nc_rg.grid_y + if regridder == MeshToGridESMFRegridder: + _compare_ignoring_var_names(original_nc_rg.grid_x, loaded_nc_rg.grid_x) + _compare_ignoring_var_names(original_nc_rg.grid_y, loaded_nc_rg.grid_y) + else: + _compare_ignoring_var_names(original_nc_rg._tgt[0], loaded_nc_rg._tgt[0]) + _compare_ignoring_var_names(original_nc_rg._tgt[1], loaded_nc_rg._tgt[1]) -def test_MeshToGridESMFRegridder_curvilinear_round_trip(tmp_path): - """Test save/load round tripping for `MeshToGridESMFRegridder`.""" - original_rg, src = _make_mesh_to_grid_regridder(grid_dims=2) +@pytest.mark.parametrize( + "regridder", + [MeshToGridESMFRegridder, ESMFAreaWeightedRegridder], +) +def test_mesh_to_grid_curvilinear_round_trip(tmp_path, regridder): + """Test save/load round tripping for mesh to grid regridding.""" + original_rg, src = _make_mesh_to_grid_regridder(regridder=regridder, grid_dims=2) filename = tmp_path / "regridder.nc" save_regridder(original_rg, filename) loaded_rg = load_regridder(str(filename)) - assert original_rg.grid_x == loaded_rg.grid_x - assert original_rg.grid_y == loaded_rg.grid_y + if regridder == MeshToGridESMFRegridder: + _compare_ignoring_var_names(original_rg.grid_x, loaded_rg.grid_x) + _compare_ignoring_var_names(original_rg.grid_y, loaded_rg.grid_y) + else: + _compare_ignoring_var_names(original_rg._tgt[0], loaded_rg._tgt[0]) + _compare_ignoring_var_names(original_rg._tgt[1], loaded_rg._tgt[1]) # Demonstrate regridding still gives the same results. src_data = ma.arange(np.product(src.data.shape)).reshape(src.data.shape) @@ -323,3 +423,128 @@ def test_MeshToGridESMFRegridder_curvilinear_round_trip(tmp_path): loaded_result = loaded_rg(src).data assert np.array_equal(original_result, loaded_result) assert np.array_equal(original_result.mask, loaded_result.mask) + + +@pytest.mark.parametrize( + "src_type,tgt_type", + [ + ("grid", "grid"), + ("grid", "mesh"), + ("mesh", "grid"), + ("mesh", "mesh"), + ], +) +@pytest.mark.parametrize( + "scheme", + [ESMFAreaWeighted, ESMFBilinear, ESMFNearest], + ids=["conservative", "linear", "nearest"], +) +def test_generic_regridder(tmp_path, src_type, tgt_type, scheme): + """Test save/load round tripping for regridders in `esmf_regrid.schemes`.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + if src_type == "grid": + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + elif src_type == "mesh": + src = _gridlike_mesh_cube(n_lons_src, n_lats_src) + if tgt_type == "grid": + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + elif tgt_type == "mesh": + tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt) + + original_rg = scheme().regridder(src, tgt) + filename = tmp_path / "regridder.nc" + save_regridder(original_rg, filename) + loaded_rg = load_regridder(str(filename)) + + if src_type == "grid": + assert original_rg._src == loaded_rg._src + if tgt_type == "grid": + assert original_rg._tgt == loaded_rg._tgt + if scheme == ESMFAreaWeighted: + assert original_rg.src_resolution == loaded_rg.src_resolution + assert original_rg.tgt_resolution == loaded_rg.tgt_resolution + assert original_rg.mdtol == loaded_rg.mdtol + + +@pytest.mark.parametrize( + "src_type,tgt_type", + [ + ("grid", "grid"), + ("grid", "mesh"), + ("mesh", "grid"), + ("mesh", "mesh"), + ], +) +@pytest.mark.parametrize( + "scheme", + [ESMFAreaWeighted, ESMFBilinear, ESMFNearest], + ids=["conservative", "linear", "nearest"], +) +def test_generic_regridder_masked(tmp_path, src_type, tgt_type, scheme): + """Test save/load round tripping for regridders in `esmf_regrid.schemes`.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + if src_type == "grid": + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + src.data = ma.array(src.data) + src.data[0, 0] = ma.masked + elif src_type == "mesh": + src = _gridlike_mesh_cube(n_lons_src, n_lats_src) + src.data = ma.array(src.data) + src.data[0] = ma.masked + if tgt_type == "grid": + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + tgt.data = ma.array(tgt.data) + tgt.data[0, 0] = ma.masked + elif tgt_type == "mesh": + tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt) + tgt.data = ma.array(tgt.data) + tgt.data[0] = ma.masked + + original_rg = scheme().regridder(src, tgt, use_src_mask=True, use_tgt_mask=True) + filename = tmp_path / "regridder.nc" + save_regridder(original_rg, filename) + loaded_rg = load_regridder(str(filename)) + + assert np.allclose(original_rg.src_mask, loaded_rg.src_mask) + assert np.allclose(original_rg.tgt_mask, loaded_rg.tgt_mask) + + +@pytest.mark.parametrize( + "scheme", + [ESMFAreaWeighted], + ids=["conservative"], +) +def test_generic_regridder_resolution(tmp_path, scheme): + """Test save/load round tripping for regridders in `esmf_regrid.schemes`.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + src_resolution = 3 + tgt_resolution = 4 + + original_rg = scheme().regridder( + src, tgt, src_resolution=src_resolution, tgt_resolution=tgt_resolution + ) + filename = tmp_path / "regridder.nc" + save_regridder(original_rg, filename) + loaded_rg = load_regridder(str(filename)) + + assert loaded_rg.src_resolution == src_resolution + assert loaded_rg.regridder.src.resolution == src_resolution + assert loaded_rg.tgt_resolution == tgt_resolution + assert loaded_rg.regridder.tgt.resolution == tgt_resolution diff --git a/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py b/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py index 234fb604..68420790 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py +++ b/esmf_regrid/tests/unit/experimental/io/test_save_regridder.py @@ -2,7 +2,11 @@ import pytest -from esmf_regrid.experimental.io import save_regridder +from esmf_regrid.experimental.io import _managed_var_name, save_regridder +from esmf_regrid.schemes import ESMFAreaWeightedRegridder +from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import ( + _gridlike_mesh_cube, +) def test_invalid_type(tmp_path): @@ -11,3 +15,48 @@ def test_invalid_type(tmp_path): filename = tmp_path / "regridder.nc" with pytest.raises(TypeError): save_regridder(invalid_obj, filename) + + +def test_var_name_preserve(tmp_path): + """Test that `save_regridder` does not change var_ames.""" + lons = 3 + lats = 4 + src = _gridlike_mesh_cube(lons, lats) + tgt = _gridlike_mesh_cube(lons, lats) + + DUMMY_VAR_NAME_SRC = "src_dummy_var" + DUMMY_VAR_NAME_TGT = "tgt_dummy_var" + for coord in src.mesh.coords(): + coord.var_name = DUMMY_VAR_NAME_SRC + for coord in tgt.mesh.coords(): + coord.var_name = DUMMY_VAR_NAME_TGT + + rg = ESMFAreaWeightedRegridder(src, tgt) + filename = tmp_path / "regridder.nc" + save_regridder(rg, filename) + + for coord in src.mesh.coords(): + assert coord.var_name == DUMMY_VAR_NAME_SRC + for coord in tgt.mesh.coords(): + assert coord.var_name == DUMMY_VAR_NAME_TGT + + +def test_managed_var_name(): + """Test that `_managed_var_name` changes var_names.""" + lons = 3 + lats = 4 + src = _gridlike_mesh_cube(lons, lats) + tgt = _gridlike_mesh_cube(lons, lats) + + DUMMY_VAR_NAME_SRC = "src_dummy_var" + DUMMY_VAR_NAME_TGT = "tgt_dummy_var" + for coord in src.mesh.coords(): + coord.var_name = DUMMY_VAR_NAME_SRC + for coord in tgt.mesh.coords(): + coord.var_name = DUMMY_VAR_NAME_TGT + + with _managed_var_name(src, tgt): + for coord in src.mesh.coords(): + assert coord.var_name != DUMMY_VAR_NAME_SRC + for coord in tgt.mesh.coords(): + assert coord.var_name != DUMMY_VAR_NAME_TGT diff --git a/esmf_regrid/tests/unit/schemes/__init__.py b/esmf_regrid/tests/unit/schemes/__init__.py index 22fc55be..c51deffd 100644 --- a/esmf_regrid/tests/unit/schemes/__init__.py +++ b/esmf_regrid/tests/unit/schemes/__init__.py @@ -1,5 +1,6 @@ """Unit tests for `esmf_regrid.schemes`.""" +import dask.array as da from iris.coord_systems import OSGB import numpy as np from numpy import ma @@ -215,3 +216,39 @@ def _test_non_degree_crs(scheme): # Check that the number of masked points is as expected. assert (1 - result.data.mask).sum() == expected_unmasked + + +def _test_dtype_handling(scheme, src_type, tgt_type, in_dtype): + """Test regridding scheme handles dtype as expected.""" + n_lons_src = 6 + n_lons_tgt = 3 + n_lats_src = 4 + n_lats_tgt = 2 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + if in_dtype == "float32": + dtype = np.float32 + elif in_dtype == "float64": + dtype = np.float64 + + if src_type == "grid": + src = _grid_cube(n_lons_src, n_lats_src, lon_bounds, lat_bounds, circular=True) + src_data = np.zeros([n_lats_src, n_lons_src], dtype=dtype) + src.data = da.array(src_data) + elif src_type == "mesh": + src = _gridlike_mesh_cube(n_lons_src, n_lats_src) + src_data = np.zeros([n_lats_src * n_lons_src], dtype=dtype) + src.data = da.array(src_data) + if tgt_type == "grid": + tgt = _grid_cube(n_lons_tgt, n_lats_tgt, lon_bounds, lat_bounds, circular=True) + elif tgt_type == "mesh": + tgt = _gridlike_mesh_cube(n_lons_tgt, n_lats_tgt) + + result = src.regrid(tgt, scheme()) + + expected_dtype = np.float64 + + assert result.has_lazy_data() + + assert result.lazy_data().dtype == expected_dtype + assert result.data.dtype == expected_dtype diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py index e73bee5d..50af91b3 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeighted.py @@ -5,6 +5,7 @@ from esmf_regrid.schemes import ESMFAreaWeighted from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, + _test_dtype_handling, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -74,3 +75,9 @@ def test_invalid_tgt_location(): def test_non_degree_crs(): """Test for coordinates with non-degree units.""" _test_non_degree_crs(ESMFAreaWeighted) + + +def test_dtype_handling(src_tgt_types, in_dtype): + """Test regridding scheme handles dtype as expected.""" + src_type, tgt_type = src_tgt_types + _test_dtype_handling(ESMFAreaWeighted, src_type, tgt_type, in_dtype) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py index c5d77ae9..8d68958e 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFAreaWeightedRegridder.py @@ -290,3 +290,28 @@ def test_masks(): weights_src_masked[:, 1:].todense(), weights_unmasked[:, 1:].todense() ) assert np.allclose(weights_tgt_masked[1:].todense(), weights_unmasked[1:].todense()) + + +def test_resolution(): + """ + Test calling of :class:`esmf_regrid.schemes.ESMFAreaWeightedRegridder`. + + Checks that the regridder accepts resolution arguments. + """ + n_lons = 6 + n_lats = 5 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + src = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + + src_resolution = 3 + tgt_resolution = 4 + + regridder = ESMFAreaWeightedRegridder( + src, tgt, src_resolution=src_resolution, tgt_resolution=tgt_resolution + ) + assert regridder.src_resolution == src_resolution + assert regridder.regridder.src.resolution == src_resolution + assert regridder.tgt_resolution == tgt_resolution + assert regridder.regridder.tgt.resolution == tgt_resolution diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py index 393bbe0a..f6a9e30f 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFBilinear.py @@ -5,6 +5,7 @@ from esmf_regrid.schemes import ESMFBilinear from esmf_regrid.tests.unit.schemes.__init__ import ( _test_cube_regrid, + _test_dtype_handling, _test_invalid_mdtol, _test_mask_from_init, _test_mask_from_regridder, @@ -63,3 +64,9 @@ def test_mask_from_regridder(mask_keyword): def test_non_degree_crs(): """Test for coordinates with non-degree units.""" _test_non_degree_crs(ESMFBilinear) + + +def test_dtype_handling(src_tgt_types, in_dtype): + """Test regridding scheme handles dtype as expected.""" + src_type, tgt_type = src_tgt_types + _test_dtype_handling(ESMFBilinear, src_type, tgt_type, in_dtype) diff --git a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py index b5812b5a..6f146e6c 100644 --- a/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py +++ b/esmf_regrid/tests/unit/schemes/test_ESMFNearest.py @@ -6,6 +6,7 @@ from esmf_regrid.schemes import ESMFNearest from esmf_regrid.tests.unit.schemes.__init__ import ( + _test_dtype_handling, _test_mask_from_init, _test_mask_from_regridder, _test_non_degree_crs, @@ -106,3 +107,9 @@ def test_mask_from_regridder(mask_keyword): def test_non_degree_crs(): """Test for coordinates with non-degree units.""" _test_non_degree_crs(ESMFNearest) + + +def test_dtype_handling(src_tgt_types, in_dtype): + """Test regridding scheme handles dtype as expected.""" + src_type, tgt_type = src_tgt_types + _test_dtype_handling(ESMFNearest, src_type, tgt_type, in_dtype) diff --git a/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py b/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py index 5670d71f..115de52a 100644 --- a/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py +++ b/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py @@ -167,7 +167,9 @@ def test_laziness(src_transposed, tgt_transposed): lat_bounds = (-90, 90) grid = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) - src_data = np.arange(n_lats * n_lons * h).reshape([n_lats, n_lons, h]) + src_data = np.arange(n_lats * n_lons * h, dtype=np.float32).reshape( + [n_lats, n_lons, h] + ) src_data = da.from_array(src_data, chunks=[3, 5, 1]) src = Cube(src_data) src.add_dim_coord(grid.coord("latitude"), 0) @@ -185,6 +187,8 @@ def test_laziness(src_transposed, tgt_transposed): assert src.has_lazy_data() result = regrid_rectilinear_to_rectilinear(src, tgt) assert result.has_lazy_data() + assert result.lazy_data().dtype == np.float64 + assert result.data.dtype == np.float64 assert np.allclose(result.data, src_data) @@ -227,7 +231,7 @@ def test_laziness_curvilinear(src_transposed, tgt_transposed): extra = AuxCoord(np.arange(e), long_name="extra dim") src_data = np.empty([h, src_lats, t, src_lons, e]) - src_data[:] = np.arange(t * h * e).reshape([h, t, e])[ + src_data[:] = np.arange(t * h * e, dtype=np.float32).reshape([h, t, e])[ :, np.newaxis, :, np.newaxis, : ] src_data_lazy = da.array(src_data) @@ -253,6 +257,8 @@ def test_laziness_curvilinear(src_transposed, tgt_transposed): result_lazy = regrid_rectilinear_to_rectilinear(src_cube_lazy, tgt_grid) assert result_lazy.has_lazy_data() + assert result.lazy_data().dtype == np.float64 + assert result.data.dtype == np.float64 assert result_lazy == result diff --git a/noxfile.py b/noxfile.py index 58b1436b..2cefbf00 100644 --- a/noxfile.py +++ b/noxfile.py @@ -5,11 +5,9 @@ """ -from datetime import datetime import os from pathlib import Path import shutil -from typing import Literal from urllib.error import HTTPError from urllib.parse import urlparse from urllib.request import urlopen @@ -313,177 +311,6 @@ def tests(session: nox.sessions.Session): session.run("pytest") -@nox.session -@nox.parametrize( - "run_type", - ["branch", "sperf", "custom"], - ids=["branch", "sperf", "custom"], -) -def benchmarks( - session: nox.sessions.Session, - run_type: Literal["overnight", "branch", "sperf", "custom"], -): - """ - Perform iris-esmf-regrid performance benchmarks (using Airspeed Velocity). - - All run types require a single Nox positional argument (e.g. - ``nox --session="foo" -- my_pos_arg``) - detailed in the parameters - section - and can optionally accept a series of further arguments that will - be added to session's ASV command. - - Parameters - ---------- - session: object - A `nox.sessions.Session` object. - run_type: {"branch", "sperf", "custom"} - * ``branch``: compares ``HEAD`` and ``HEAD``'s merge-base with the - input **base branch**. Fails if a performance regression is detected. - This is the session used by IER's CI. - * ``sperf``: Run the on-demand SPerf suite of benchmarks (part of the - UK Met Office NG-VAT project) for the ``HEAD`` of ``upstream/main`` - only, and publish the results to the input **publish directory**, - within a unique subdirectory for this run. - * ``custom``: run ASV with the input **ASV sub-command**, without any - preset arguments - must all be supplied by the user. So just like - running ASV manually, with the convenience of re-using the session's - scripted setup steps. - - Examples - -------- - * ``nox --session="benchmarks(branch)" -- upstream/main`` - * ``nox --session="benchmarks(branch)" -- upstream/mesh-data-model`` - * ``nox --session="benchmarks(branch)" -- upstream/main --bench=ci`` - * ``nox --session="benchmarks(sperf)" -- my_publish_dir - * ``nox --session="benchmarks(custom)" -- continuous a1b23d4 HEAD --quick`` - - """ - # Make sure we're not working with a list of Python versions. - if not isinstance(PY_VER, str): - message = ( - "benchmarks session requires PY_VER to be a string - representing " - f"a single Python version - instead got: {type(PY_VER)} ." - ) - raise ValueError(message) - - # The threshold beyond which shifts are 'notable'. See `asv compare`` docs - # for more. - COMPARE_FACTOR = 2.0 - - session.install("asv", "nox", "pyyaml") - session.run("conda", "install", "--yes", "conda<24.3") - - data_gen_var = "DATA_GEN_PYTHON" - if data_gen_var in os.environ: - print("Using existing data generation environment.") - data_gen_python = Path(os.environ[data_gen_var]) - else: - print("Setting up the data generation environment...") - # Get Nox to build an environment for the `tests` session, but don't - # run the session. Will re-use a cached environment if appropriate. - session.run_always( - "nox", - "--session=tests", - "--install-only", - f"--python={PY_VER}", - ) - # Find the environment built above, set it to be the data generation - # environment. - data_gen_python = next( - Path(".nox").rglob(f"tests*/bin/python{PY_VER}") - ).resolve() - session.env[data_gen_var] = data_gen_python - - print("Running ASV...") - session.cd("benchmarks") - # Skip over setup questions for a new machine. - session.run("asv", "machine", "--yes") - - # All run types require one Nox posarg. - run_type_arg = { - "branch": "base branch", - "sperf": "publish directory", - "custom": "ASV sub-command", - } - if run_type not in run_type_arg.keys(): - message = f"Unsupported run-type: {run_type}" - raise NotImplementedError(message) - if not session.posargs: - message = ( - f"Missing mandatory first Nox session posarg: " f"{run_type_arg[run_type]}" - ) - raise ValueError(message) - first_arg = session.posargs[0] - # Optional extra arguments to be passed down to ASV. - asv_args = session.posargs[1:] - - if run_type == "branch": - base_branch = first_arg - git_command = f"git merge-base HEAD {base_branch}" - merge_base = session.run(*git_command.split(" "), silent=True, external=True)[ - :8 - ] - - try: - asv_command = [ - "asv", - "continuous", - merge_base, - "HEAD", - f"--factor={COMPARE_FACTOR}", - ] - session.run(*asv_command, *asv_args) - finally: - asv_command = [ - "asv", - "compare", - merge_base, - "HEAD", - f"--factor={COMPARE_FACTOR}", - "--split", - ] - session.run(*asv_command) - - elif run_type == "sperf": - publish_dir = Path(first_arg) - if not publish_dir.is_dir(): - message = f"Input 'publish directory' is not a directory: {publish_dir}" - raise NotADirectoryError(message) - publish_subdir = ( - publish_dir / f"{run_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - ) - publish_subdir.mkdir() - - # Activate on demand benchmarks (C/SPerf are deactivated for 'standard' runs). - session.env["ON_DEMAND_BENCHMARKS"] = "True" - commit_range = "upstream/main^!" - - asv_command = [ - "asv", - "run", - commit_range, - "--bench=.*Scalability.*", - "--attribute", - "rounds=1", - ] - session.run(*asv_command, *asv_args) - - asv_command = ["asv", "publish", commit_range, f"--html-dir={publish_subdir}"] - session.run(*asv_command) - - # Print completion message. - location = Path().cwd() / ".asv" - print( - f'New ASV results for "{run_type}".\n' - f'See "{publish_subdir}",' - f'\n or JSON files under "{location / "results"}".' - ) - - else: - asv_subcommand = first_arg - assert run_type == "custom" - session.run("asv", asv_subcommand, *asv_args) - - @nox.session(python=PY_VER, venv_backend="conda") def wheel(session: nox.sessions.Session): """ diff --git a/pyproject.toml b/pyproject.toml index 98611a06..6e662229 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ include = '\.pyi?$' [tool.pytest.ini_options] addopts = "-ra -v --doctest-modules" minversion = "6.0" -doctest_optionflads = "NORMALIZE_WHITESPACE ELLIPSIS NUMBER" +doctest_optionflags = "NORMALIZE_WHITESPACE ELLIPSIS NUMBER" testpaths = "esmf_regrid" [tool.setuptools]