diff --git a/.github/workflows/branch_ci.yml b/.github/workflows/branch_ci.yml new file mode 100644 index 00000000..bb330309 --- /dev/null +++ b/.github/workflows/branch_ci.yml @@ -0,0 +1,168 @@ +# Workflow that runs on pushes to non-default branches + +name: Non-Default Branch CI (Python) + +on: + push: + branches-ignore: ["main"] + +# Specify concurrency such that only one workflow can run at a time +# * Different workflow files are not affected +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +# Registry for storing Container images +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +# Ensure the GitHub token can remove packages +permissions: + packages: write + + +jobs: + + lint-typecheck: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup uv + uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + cache-dependency-glob: "pyproject.toml" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version-file: "pyproject.toml" + + - name: Install package + run: uv sync --all-extras + + - name: Lint package + run: uv run ruff check --output-format=github . + + - name: Typecheck package + run: uv run mypy . + + test-unit: + runs-on: ubuntu-latest + needs: lint-typecheck + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup uv + uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + cache-dependency-glob: "pyproject.toml" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version-file: "pyproject.toml" + + - name: Install package + run: uv sync --all-extras + + # Run unittests + # * Produce JUnit XML report + - name: Run unit tests + run: uv run xmlrunner discover -s src/nwp_consumer -p "test_*.py" --output-file ut-report.xml + + # Create test summary to be visualised on the job summary screen on GitHub + # * Runs even if previous steps fail + - name: Create test summary + uses: test-summary/action@v2 + with: + paths: "*t-report.xml" + show: "fail, skip" + if: always() + + # Define a job that builds the documentation + # * Surfaces the documentation as an artifact + build-docs: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup uv + uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + cache-dependency-glob: "pyproject.toml" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version-file: "pyproject.toml" + + - name: Build documentation + run: uv run pydoctor + + - name: Upload documentation + uses: actions/upload-artifact@v4 + with: + name: docs + path: docs + + # * Builds and pushes an OCI Container image to the registry defined in the environment variables + # * Only runs if test job passes + build-container: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + needs: ["lint-typecheck", "test-unit"] + + steps: + # Do a non-shallow clone of the repo to ensure tags are present + # * This allows setuptools-git-versioning to automatically set the version + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Buildx + uses: docker/setup-buildx-action@v2 + + - name: Log in to the Container registry + uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # Tag the built image according to the event type + # The 
event is a branch commit, so use the commit sha + - name: Extract metadata (tags, labels) for Container + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + + # Build and push the Container image to the registry + # * Creates a multiplatform-aware image + # * Pulls build cache from the registry + - name: Build and push container image + uses: docker/build-push-action@v4 + with: + context: . + file: Containerfile + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + platforms: linux/amd64,linux/arm64 + cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 50116791..eb7e2d05 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -110,6 +110,7 @@ jobs: runs-on: ubuntu-latest container: quay.io/condaforge/miniforge3:latest needs: build-venv + if: github.event.name == ' steps: - name: Checkout repository diff --git a/src/nwp_consumer/cmd/main.py b/src/nwp_consumer/cmd/main.py index 016c43d9..30327474 100644 --- a/src/nwp_consumer/cmd/main.py +++ b/src/nwp_consumer/cmd/main.py @@ -1,51 +1,59 @@ """Entrypoints to the nwp-consumer service.""" -import argparse import logging import os import sys +from typing import NamedTuple -from nwp_consumer.internal import handlers, repositories, services +from nwp_consumer.internal import handlers, ports, repositories, services log = logging.getLogger("nwp-consumer") -def parse_env() -> argparse.Namespace: +class Adaptors(NamedTuple): + """Adaptors for the CLI.""" + model_repository: type[ports.ModelRepository] + notification_repository: type[ports.NotificationRepository] + +def parse_env() -> Adaptors: """Parse from the environment.""" - config = argparse.Namespace() + model_repository_adaptor: type[ports.ModelRepository] match os.getenv("MODEL_REPOSITORY"): case None: log.error("MODEL_REPOSITORY is not set in environment.") sys.exit(1) case "ceda-metoffice-global": - config.model_repository = repositories.CedaMetOfficeGlobalModelRepository() + model_repository_adaptor = repositories.CedaMetOfficeGlobalModelRepository case _ as model: log.error(f"Unknown model: {model}") sys.exit(1) + notification_repository_adaptor: type[ports.NotificationRepository] match os.getenv("NOTIFICATION_REPOSITORY", "stdout"): case "stdout": - config.notification_repository = repositories.StdoutNotificationRepository() + notification_repository_adaptor = repositories.StdoutNotificationRepository case "dagster-pipes": - config.notification_repository = repositories.DagsterPipesNotificationRepository() + notification_repository_adaptor = repositories.DagsterPipesNotificationRepository case _ as notification: log.error(f"Unknown notification repository: {notification}") sys.exit(1) - return config + return Adaptors( + model_repository=model_repository_adaptor, + notification_repository=notification_repository_adaptor, + ) def run_cli() -> None: """Entrypoint for the CLI handler.""" - args = parse_env() + adaptors = parse_env() c = handlers.CLIHandler( consumer_usecase=services.ConsumerService( - model_repository=args.model_repository, - zarr_repository=None, - notification_repository=args.notification_repository, + model_repository=adaptors.model_repository, + notification_repository=adaptors.notification_repository, ), archiver_usecase=services.ArchiverService( - model_repository=args.model_repository, - 
notification_repository=args.notification_repository, + model_repository=adaptors.model_repository, + notification_repository=adaptors.notification_repository, ), ) returncode: int = c.run() diff --git a/src/nwp_consumer/internal/config/__init__.py b/src/nwp_consumer/internal/config/__init__.py deleted file mode 100644 index c19df951..00000000 --- a/src/nwp_consumer/internal/config/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Configuration for the service. - -The service is configured via environment variables in accordance with the -12-factor philosophy. -""" - -from .config import AppConfig, parse - -__all__ = [ - "AppConfig", - "parse", -] diff --git a/src/nwp_consumer/internal/config/config.py b/src/nwp_consumer/internal/config/config.py deleted file mode 100644 index f564fbf2..00000000 --- a/src/nwp_consumer/internal/config/config.py +++ /dev/null @@ -1,21 +0,0 @@ -import dataclasses -import os - -import dacite - - -@dataclasses.dataclass -class AppConfig: - """App configuration.""" - LOGLEVEL: str = "INFO" - MODEL_REPOSITORY: str = "ceda-metoffice-global" - NOTIFICATION_REPOSITORY: str = "stdout" - WORKDIR: str = "~/.local/cache/nwp" - -def parse() -> AppConfig: - """Parse the configuration from the environment.""" - return dacite.from_dict( - data_class=AppConfig, - data=os.environ, - ) - diff --git a/src/nwp_consumer/internal/entities/__init__.py b/src/nwp_consumer/internal/entities/__init__.py index ded638e4..9b71b7fc 100644 --- a/src/nwp_consumer/internal/entities/__init__.py +++ b/src/nwp_consumer/internal/entities/__init__.py @@ -15,7 +15,7 @@ should not contain any logic that is specific to a particular implementation. """ -from .repometadata import ModelRepositoryMetadata, ModelFileMetadata +from .repometadata import ModelRepositoryMetadata, ModelMetadata from .tensorstore import ParameterScanResult, TensorStore from .postprocess import PostProcessOptions, CodecOptions from .notification import PerformanceMetadata, StoreCreatedNotification, StoreAppendedNotification @@ -25,7 +25,7 @@ __all__ = [ "ModelRepositoryMetadata", - "ModelFileMetadata", + "ModelMetadata", "ParameterScanResult", "TensorStore", "PostProcessOptions", diff --git a/src/nwp_consumer/internal/entities/coordinates.py b/src/nwp_consumer/internal/entities/coordinates.py index 738a2cae..853ac7d6 100644 --- a/src/nwp_consumer/internal/entities/coordinates.py +++ b/src/nwp_consumer/internal/entities/coordinates.py @@ -42,7 +42,7 @@ import pandas as pd import pytz import xarray as xr -from returns.result import Failure, Result, ResultE, Success +from returns.result import Failure, ResultE, Success from .parameters import Parameter @@ -157,7 +157,7 @@ def from_pandas( )) # Convert the pandas Index objects to lists of the appropriate types - return Result.from_value( + return Success( cls( # NOTE: The timezone information is stripped from the datetime objects # as numpy cannot handle timezone-aware datetime objects. As such, it @@ -289,7 +289,7 @@ def determine_region( """ # Ensure the inner and outer maps have the same rank and dimension labels if inner.dims != self.dims: - return Result.from_failure( + return Failure( KeyError( "Cannot find slices in non-matching coordinate mappings: " "both objects must have identical dimensions (rank and labels)." 
@@ -303,7 +303,7 @@ def determine_region( inner_dim_coords = getattr(inner, inner_dim_label) outer_dim_coords = getattr(self, inner_dim_label) if len(inner_dim_coords) > len(outer_dim_coords): - return Result.from_failure( + return Failure( ValueError( f"Coordinate values for dimension '{inner_dim_label}' in the inner map " "exceed the number of coordinate values in the outer map. " @@ -314,7 +314,7 @@ def determine_region( if not set(inner_dim_coords).issubset(set(outer_dim_coords)): diff_coords = list(set(inner_dim_coords).difference(set(outer_dim_coords))) first_diff_index: int = inner_dim_coords.index(diff_coords[0]) - return Result.from_failure( + return Failure( ValueError( f"Coordinate values for dimension '{inner_dim_label}' not all present " "within outer dimension map. The inner map must be entirely contained " @@ -338,7 +338,7 @@ def determine_region( # TODO: of which might loop around the edges of the grid. In this case, it would # TODO: be useful to determine if the run is non-contiguous only in that it wraps # TODO: around that boundary, and in that case, split it and write it in two goes. - return Result.from_failure( + return Failure( ValueError( f"Coordinate values for dimension '{inner_dim_label}' do not correspond " f"with a contiguous index set in the outer dimension map. " @@ -349,7 +349,7 @@ def determine_region( slices[inner_dim_label] = slice(outer_dim_indices[0], outer_dim_indices[-1] + 1) - return Result.from_value(slices) + return Success(slices) def default_chunking(self) -> dict[str, int]: """The expected chunk sizes for each dimension. diff --git a/src/nwp_consumer/internal/entities/parameters.py b/src/nwp_consumer/internal/entities/parameters.py index 97ddc44f..4bbc6126 100644 --- a/src/nwp_consumer/internal/entities/parameters.py +++ b/src/nwp_consumer/internal/entities/parameters.py @@ -109,36 +109,38 @@ class Parameter(StrEnum): CLOUD_COVER_LOW = auto() CLOUD_COVER_TOTAL = auto() TOTAL_PRECIPITATION_RATE_GL = auto() + DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL = auto() + DIRECT_SHORTWAVE_RADIATION_FLUX_GL = auto() def metadata(self) -> ParameterData: """Get the metadata for the parameter.""" - match self: - case self.TEMPERATURE_SL: + match self.name: + case self.TEMPERATURE_SL.name: return ParameterData( name=str(self), description="Temperature at screen level", units="C", limits=ParameterLimits(upper=60, lower=-90), ) - case self.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL: + case self.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.name: return ParameterData( name=str(self), description="Downward shortwave radiation flux at ground level. " - "Defined as the mean amount of solar radiation incident on the surface " - "expected over the next hour.", + "Defined as the mean amount of solar radiation " + "incident on the surface expected over the next hour.", units="W/m^2", limits=ParameterLimits(upper=1500, lower=0), ) - case self.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL: + case self.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL.name: return ParameterData( name=str(self), description="Downward longwave radiation flux at ground level. " - "Defined as the mean amount of thermal radiation incident on the surface " - "expected over the next hour.", + "Defined as the mean amount of thermal radiation " + "incident on the surface expected over the next hour.", units="W/m^2", limits=ParameterLimits(upper=500, lower=0), ) - case self.RELATIVE_HUMIDITY_SL: + case self.RELATIVE_HUMIDITY_SL.name: return ParameterData( name=str(self), description="Relative humidity at screen level. 
" @@ -147,7 +149,7 @@ def metadata(self) -> ParameterData: units="%", limits=ParameterLimits(upper=100, lower=0), ) - case self.VISIBILITY_SL: + case self.VISIBILITY_SL.name: return ParameterData( name=str(self), description="Visibility at screen level. " @@ -156,99 +158,109 @@ def metadata(self) -> ParameterData: units="m", limits=ParameterLimits(upper=4500, lower=0), ) - case self.WIND_U_COMPONENT_10m: + case self.WIND_U_COMPONENT_10m.name: return ParameterData( name=str(self), description="U component of wind at 10m above ground level. " - "Defined as the horizontal speed of the wind in the eastward direction.", + "Defined as the horizontal speed of " + "the wind in the eastward direction.", units="m/s", limits=ParameterLimits(upper=100, lower=-100), ) - case self.WIND_V_COMPONENT_10m: + case self.WIND_V_COMPONENT_10m.name: return ParameterData( name=str(self), description="V component of wind at 10m above ground level. " - "Defined as the horizontal speed of the wind in the northward direction.", + "Defined as the horizontal speed of " + "the wind in the northward direction.", units="m/s", # Non-tornadic winds are usually < 100m/s limits=ParameterLimits(upper=100, lower=-100), ) - case self.WIND_U_COMPONENT_100m: + case self.WIND_U_COMPONENT_100m.name: return ParameterData( name=str(self), description="U component of wind at 100m above ground level. " - "Defined as the horizontal speed of the wind in the eastward direction.", + "Defined as the horizontal speed of " + "the wind in the eastward direction.", units="m/s", limits=ParameterLimits(upper=100, lower=-100), ) - case self.WIND_V_COMPONENT_100m: + case self.WIND_V_COMPONENT_100m.name: return ParameterData( name=str(self), description="V component of wind at 100m above ground level. " - "Defined as the horizontal speed of the wind in the northward direction.", + "Defined as the horizontal speed of " + "the wind in the northward direction.", units="m/s", limits=ParameterLimits(upper=100, lower=-100), ) - case self.WIND_U_COMPONENT_200m: + case self.WIND_U_COMPONENT_200m.name: return ParameterData( name=str(self), description="U component of wind at 200m above ground level. " - "Defined as the horizontal speed of the wind in the eastward direction.", + "Defined as the horizontal speed of " + "the wind in the eastward direction.", units="m/s", limits=ParameterLimits(upper=150, lower=-150), ) - case self.WIND_V_COMPONENT_200m: + case self.WIND_V_COMPONENT_200m.name: return ParameterData( name=str(self), description="V component of wind at 200m above ground level. " - "Defined as the horizontal speed of the wind in the northward direction.", + "Defined as the horizontal speed of " + "the wind in the northward direction.", units="m/s", limits=ParameterLimits(upper=150, lower=-150), ) - case self.SNOW_DEPTH_GL: + case self.SNOW_DEPTH_GL.name: return ParameterData( name=str(self), description="Depth of snow on the ground.", units="m", limits=ParameterLimits(upper=12, lower=0), ) - case self.CLOUD_COVER_HIGH: + case self.CLOUD_COVER_HIGH.name: return ParameterData( name=str(self), description="Fraction of grid square covered by high-level cloud. 
" - "Defined as the ratio of the area of the grid square covered by high-level " - "(>6km) cloud to the square's total area.", + "Defined as the ratio of " + "the area of the grid square covered by high-level (>6km) cloud " + "to the square's total area.", units="UI", limits=ParameterLimits(upper=1, lower=0), ) - case self.CLOUD_COVER_MEDIUM: + case self.CLOUD_COVER_MEDIUM.name: return ParameterData( name=str(self), description="Fraction of grid square covered by medium-level cloud. " - "Defined as the ratio of the area of the grid square covered by medium-level " - "(2-6km) cloud to the square's total area.", + "Defined as the ratio of " + "the area of the grid square covered by medium-level (2-6km) cloud " + "to the square's total area.", units="UI", limits=ParameterLimits(upper=1, lower=0), ) - case self.CLOUD_COVER_LOW: + case self.CLOUD_COVER_LOW.name: return ParameterData( name=str(self), description="Fraction of grid square covered by low-level cloud. " - "Defined as the ratio of the area of the grid square covered by low-level " - "(<2km) cloud to the square's total area.", + "Defined as the ratio of " + "the area of the grid square covered by low-level (<2km) cloud " + "to the square's total area.", units="UI", limits=ParameterLimits(upper=1, lower=0), ) - case self.CLOUD_COVER_TOTAL: + case self.CLOUD_COVER_TOTAL.name: return ParameterData( name=str(self), description="Fraction of grid square covered by any cloud. " - "Defined as the ratio of the area of the grid square covered by any " - "cloud to the square's total area.", + "Defined as the ratio of " + "the area of the grid square covered by any cloud " + "to the square's total area.", units="UI", limits=ParameterLimits(upper=1, lower=0), ) - case self.TOTAL_PRECIPITATION_RATE_GL: + case self.TOTAL_PRECIPITATION_RATE_GL.name: return ParameterData( name=str(self), description="Total precipitation rate at ground level. " @@ -257,6 +269,27 @@ def metadata(self) -> ParameterData: units="kg/m^2/s", limits=ParameterLimits(upper=0.2, lower=0), ) + case self.DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL.name: + return ParameterData( + name=str(self), + description="Downward ultraviolet radiation flux at ground level. " + "Defined as the mean amount of " + "ultraviolet radiation incident on the surface " + "expected over the next hour.", + units="W/m^2", + limits=ParameterLimits(upper=1000, lower=0), + ) + case self.DIRECT_SHORTWAVE_RADIATION_FLUX_GL.name: + return ParameterData( + name=str(self), + description="Direct shortwave radiation flux at ground level. " + "Defined as the mean amount of " + "unscattered solar radiation incident on" + "a surface plane perpendicular to the direction of the sun " + "expected over the next hour.", + units="W/m^2", + limits=ParameterLimits(upper=1000, lower=0), + ) case _: # Shouldn't happen thanks to the test case in test_parameters.py raise ValueError(f"Unknown parameter: {self}") diff --git a/src/nwp_consumer/internal/entities/repometadata.py b/src/nwp_consumer/internal/entities/repometadata.py index 47587b8a..ab3a9b6b 100644 --- a/src/nwp_consumer/internal/entities/repometadata.py +++ b/src/nwp_consumer/internal/entities/repometadata.py @@ -6,11 +6,15 @@ it provides. This might be helpful in determining the quality of the data, defining pipelines for processing, or establishing the availability for a live service. 
+ +In this instance, the `ModelMetadata` refers to information pertaining +to the model used to generate the data itself, whilst the +`ModelRepositoryMetadata` refers to information about the repository +where NWP data produced by the model resides. """ import dataclasses import datetime as dt -import pathlib import pandas as pd @@ -19,15 +23,61 @@ @dataclasses.dataclass(slots=True) -class ModelRepositoryMetadata: - """Metadata for an NWP Model repository.""" +class ModelMetadata: + """Metadata for an NWP model.""" name: str """The name of the model. - Also used to name the tensor in the zarr store. + Used to name the tensor in the zarr store. + """ + + resolution: str + """The resolution of the model with units.""" + + expected_coordinates: NWPDimensionCoordinateMap + """The expected dimension coordinate mapping. + + This is a dictionary mapping dimension labels to their coordinate values, + for a single init time dataset, e.g. + + >>> { + >>> "init_time": [dt.datetime(2021, 1, 1, 0, 0), ...], + >>> "step": [1, 2, ...], + >>> "latitude": [90, 89.75, 89.5, ...], + >>> "longitude": [180, 179, ...], + >>> } + + To work this out, it can be useful to use the 'grib_ls' tool from eccodes: + + >>> grib_ls -n geography -wcount=13 raw_file.grib + + Which prints grid data from the grib file. """ + def __str__(self) -> str: + """Return a pretty-printed string representation of the metadata.""" + pretty: str = "".join(( + "Model:", + f"\n\t{self.name} ({self.resolution} resolution)", + "\n\tCoordinates:\n", + "\n".join( + f"\t\t{dim}: {vals}" + if len(vals) < 5 + else f"\t\t{dim}: {vals[:3]} ... {vals[-3:]}" + for dim, vals in self.expected_coordinates.__dict__.items() + ), + )) + return pretty + + +@dataclasses.dataclass(slots=True) +class ModelRepositoryMetadata: + """Metadata for an NWP Model repository.""" + + name: str + """The name of the model repository.""" + is_archive: bool """Whether the repository is a complete archival set. @@ -66,26 +116,6 @@ class ModelRepositoryMetadata: downloading data from the repository. """ - expected_coordinates: NWPDimensionCoordinateMap - """The expected dimension coordinate mapping. - - This is a dictionary mapping dimension labels to their coordinate values, - for a single init time dataset, e.g. - - >>> { - >>> "init_time": [dt.datetime(2021, 1, 1, 0, 0), ...], - >>> "step": [1, 2, ...], - >>> "latitude": [90, 89.75, 89.5, ...], - >>> "longitude": [180, 179, ...], - >>> } - - To work this out, it can be useful to use the 'grib_ls' tool from eccodes: - - >>> grib_ls -n geography -wcount=13 raw_file.grib - - Which prints grid data from the grib file. - """ - postprocess_options: PostProcessOptions """Options for post-processing the data.""" @@ -116,48 +146,15 @@ def month_its(self, year: int, month: int) -> list[dt.datetime]: def __str__(self) -> str: """Return a pretty-printed string representation of the metadata.""" - pretty: str = "\n".join(( - f"Model: {self.name} ({'archive' if self.is_archive else 'live/rolling'} dataset.)", - f"\truns at: {self.running_hours} hours (available after {self.delay_minutes} minute delay)", - "\tCoordinates:", - "\n".join(f"\t\t{dim}: {vals}" - if len(vals) < 5 - else f"\t\t{dim}: {vals[:3]} ... 
{vals[-3:]}" - for dim, vals in self.expected_coordinates.__dict__.items() - ), - "Environment variables:", - "\tRequired:", + pretty: str = "".join(( + "Model Repository: ", + f"\n\t{self.name} ({'archive' if self.is_archive else 'live/rolling'} dataset.)", + f"\n\truns at: {self.running_hours} hours ", + "(available after {self.delay_minutes} minute delay)", + "\nEnvironment variables:", + "\n\tRequired:", "\n".join(f"\t\t{var}" for var in self.required_env), - "\tOptional:", + "\n\tOptional:", "\n".join(f"\t\t{var}={val}" for var, val in self.optional_env.items()), )) return pretty - -@dataclasses.dataclass(slots=True) -class ModelFileMetadata: - """Metadata for a raw file.""" - - name: str - """The name of the file.""" - - path: pathlib.Path - """The relevant (remote or local) path to the file.""" - - scheme: str - """The scheme of the path (e.g. 'https', 'ftp', 'file').""" - - extension: str - """The file extension, including the dot (e.g. '.grib').""" - - size_bytes: int - """The size of the file in bytes.""" - - parameters: list[str] - """The parameters within the file.""" - - steps: list[int] - """The steps contained in the file.""" - - init_time: dt.datetime - """The init time the file data corresponds to.""" - diff --git a/src/nwp_consumer/internal/entities/tensorstore.py b/src/nwp_consumer/internal/entities/tensorstore.py index c81c66a5..7f3946c6 100644 --- a/src/nwp_consumer/internal/entities/tensorstore.py +++ b/src/nwp_consumer/internal/entities/tensorstore.py @@ -22,11 +22,10 @@ import pandas as pd import xarray as xr import zarr -from returns.iterables import Fold -from returns.result import Failure, Result, ResultE, Success +from returns.result import Failure, ResultE, Success from .coordinates import NWPDimensionCoordinateMap -from .parameters import Parameter, ParameterData +from .parameters import Parameter from .postprocess import PostProcessOptions log = logging.getLogger("nwp-consumer") @@ -128,7 +127,7 @@ def initialize_empty_store( A new instance of the TensorStore class. """ if not isinstance(coords.init_time, list) or len(coords.init_time) == 0: - return Result.from_failure( + return Failure( ValueError( "Cannot initialize store with 'init_time' dimension coordinates not " "specified via a populated list. Check instantiation of " @@ -158,7 +157,7 @@ def initialize_empty_store( } # Create a dask array of zeros with the shape of the dataset # * The values of this are ignored, only the shape and chunks are used - dummy_values = dask.array.zeros( + dummy_values = dask.array.zeros( # type: ignore shape=list(coords.shapemap.values()), chunks=tuple([intermediate_chunks[k] for k in coords.shapemap]), ) @@ -191,7 +190,7 @@ def initialize_empty_store( store_da: xr.DataArray = xr.open_dataarray(store_path, engine="zarr") for dim in store_da.dims: if dim not in da.dims: - return Result.from_failure( + return Failure( ValueError( "Cannot use existing store due to mismatched coordinates. " f"Dimension '{dim}' in existing store not found in new store. " @@ -200,7 +199,7 @@ def initialize_empty_store( ), ) if not np.array_equal(store_da.coords[dim].values, da.coords[dim].values): - return Result.from_failure( + return Failure( ValueError( "Cannot use existing store due to mismatched coordinates. 
" f"Dimension '{dim}' in existing store has different coordinate " @@ -224,7 +223,7 @@ def initialize_empty_store( # Ensure the store is readable store_da = xr.open_dataarray(store_path, engine="zarr") except Exception as e: - return Result.from_failure( + return Failure( OSError( f"Failed writing blank store to disk: {e}", ), @@ -232,14 +231,14 @@ def initialize_empty_store( # Check the resultant array's coordinates can be converted back coordinate_map_result = NWPDimensionCoordinateMap.from_xarray(store_da) if isinstance(coordinate_map_result, Failure): - return Result.from_failure( + return Failure( OSError( f"Error reading back coordinates of initialized store " f"from disk (possible corruption): {coordinate_map_result}", ), ) - return Result.from_value( + return Success( cls( name=name, path=store_path, @@ -276,14 +275,14 @@ def write_to_region( self.coordinate_map.determine_region, ) if isinstance(region_result, Failure): - return Result.from_failure(region_result.failure()) + return Failure(region_result.failure()) region = region_result.unwrap() # Perform the regional write try: da.to_zarr(store=self.path, region=region, consolidated=True) except Exception as e: - return Result.from_failure( + return Failure( OSError( f"Error writing to region of store: {e}", ), @@ -293,7 +292,7 @@ def write_to_region( nbytes: int = da.nbytes del da self.size_mb += nbytes // (1024**2) - return Result.from_value(nbytes) + return Success(nbytes) def validate_store(self) -> ResultE[bool]: """Validate the store. @@ -308,29 +307,29 @@ def validate_store(self) -> ResultE[bool]: coords_result = NWPDimensionCoordinateMap.from_xarray(store_da) match coords_result: case Failure(e): - return Result.from_failure(e) + return Failure(e) case Success(coords): if coords != self.coordinate_map: - return Result.from_failure( - ValueError( - "Coordinate consistency check failed: " - "Store coordinates do not match expected coordinates. " - f"Expected: {self.coordinate_map}. Got: {coords}.", - ), - ) + return Failure(ValueError( + "Coordinate consistency check failed: " + "Store coordinates do not match expected coordinates. " + f"Expected: {self.coordinate_map}. Got: {coords}.", + )) # Validity check on the parameters of the store for param in self.coordinate_map.variable: - scan_result: ResultE[ParameterScanResult] = self.scan_parameter_values(p=param.metadata()) + scan_result: ResultE[ParameterScanResult] = self.scan_parameter_values(p=param) match scan_result: case Failure(e): - return Result.from_failure(e) + return Failure(e) case Success(scan): log.debug(f"Scanned parameter {param.name}: {scan.__repr__()}") if not scan.is_valid or scan.has_nulls: - return Result.from_value(False) + return Success(False) + + return Success(True) - def scan_parameter_values(self, p: ParameterData) -> ResultE[ParameterScanResult]: + def scan_parameter_values(self, p: Parameter) -> ResultE[ParameterScanResult]: """Scan the values of a parameter in the store. Extracts data from the values of the given parameter in the store. @@ -344,16 +343,14 @@ def scan_parameter_values(self, p: ParameterData) -> ResultE[ParameterScanResult A ParameterScanResult object. """ if p not in self.coordinate_map.variable: - return Result.from_failure( - KeyError( - "Parameter scan failed: " - f"Cannot validate unknown parameter: {p.name}. " - "Ensure the parameter has been renamed to match the entities " - "parameters defined in `entities.parameters` if desired, or " - "add the parameter to the entities parameters if it is new. 
" - f"Store parameters: {[p.name for p in self.coordinate_map.variable]}.", - ), - ) + return Failure(KeyError( + "Parameter scan failed: " + f"Cannot validate unknown parameter: {p.name}. " + "Ensure the parameter has been renamed to match the entities " + "parameters defined in `entities.parameters` if desired, or " + "add the parameter to the entities parameters if it is new. " + f"Store parameters: {[p.name for p in self.coordinate_map.variable]}.", + )) store_da: xr.DataArray = xr.open_dataarray(self.path, engine="zarr") # Calculating the mean of a dataarray returns another dataarray, so it @@ -362,7 +359,7 @@ def scan_parameter_values(self, p: ParameterData) -> ResultE[ParameterScanResult # second call to `mean()` helps to reassure them its a float. mean = store_da.mean().values.mean() - return Result.from_value( + return Success( ParameterScanResult( mean=mean, is_valid=True, @@ -410,7 +407,7 @@ def postprocess(self, options: PostProcessOptions) -> ResultE[pathlib.Path]: coordinates_result = NWPDimensionCoordinateMap.from_xarray(store_da) match coordinates_result: case Failure(e): - return Result.from_failure(e) + return Failure(e) case Success(coords): self.coordinate_map = coords @@ -426,7 +423,7 @@ def postprocess(self, options: PostProcessOptions) -> ResultE[pathlib.Path]: # * See https://github.com/sgkit-dev/sgkit/issues/991 # * and https://github.com/pydata/xarray/issues/3476 store_da.coords["variable"].encoding.clear() - zstore: xr.backends.zarr.ZarrStore = store_da.to_zarr( + _ = store_da.to_zarr( store=processed_path, mode="w", encoding=self.encoding, @@ -434,7 +431,7 @@ def postprocess(self, options: PostProcessOptions) -> ResultE[pathlib.Path]: ) self.path = processed_path except Exception as e: - return Result.from_failure( + return Failure( OSError( f"Error encountered writing postprocessed store: {e}", ), @@ -448,17 +445,17 @@ def postprocess(self, options: PostProcessOptions) -> ResultE[pathlib.Path]: try: shutil.make_archive(self.path.name, "zip", self.path) except Exception as e: - return Result.from_failure( + return Failure( OSError( f"Error encountered zipping store: {e}", ), ) log.debug("Postprocessing complete for store %s", self.name) - return Result.from_value(self.path) + return Success(self.path) else: - return Result.from_value(self.path) + return Success(self.path) def update_attrs(self, attrs: dict[str, str]) -> ResultE[pathlib.Path]: """Update the attributes of the store. @@ -468,7 +465,7 @@ def update_attrs(self, attrs: dict[str, str]) -> ResultE[pathlib.Path]: group: zarr.Group = zarr.open_group(self.path.as_posix()) group.attrs.update(attrs) zarr.consolidate_metadata(self.path.as_posix()) - return Result.from_value(self.path) + return Success(self.path) def missing_times(self) -> ResultE[list[dt.datetime]]: """Find the missing init_time in the store. 
@@ -480,7 +477,7 @@ def missing_times(self) -> ResultE[list[dt.datetime]]: try: store_da: xr.DataArray = xr.open_dataarray(self.path, engine="zarr") except Exception as e: - return Result.from_failure(OSError( + return Failure(OSError( "Cannot determine missing times in store due to " f"error reading '{self.path}': {e}", )) @@ -491,6 +488,6 @@ def missing_times(self) -> ResultE[list[dt.datetime]]: if d != "init_time" }).isnull().all().values: missing_times.append(pd.Timestamp(it).to_pydatetime().replace(tzinfo=dt.UTC)) - return Result.from_value(missing_times) + return Success(missing_times) diff --git a/src/nwp_consumer/internal/entities/test_coordinates.py b/src/nwp_consumer/internal/entities/test_coordinates.py index 087a9f7f..0f2b10c0 100644 --- a/src/nwp_consumer/internal/entities/test_coordinates.py +++ b/src/nwp_consumer/internal/entities/test_coordinates.py @@ -91,7 +91,7 @@ class TestCase: init_time=outer.init_time[:1], step=[15], variable=outer.variable, - latitude=[*outer.latitude, 64.0], + latitude=[12, 13, 14, 15], longitude=outer.longitude, ), expected_slices={}, @@ -126,7 +126,7 @@ def test_to_pandas(self) -> None: class TestCase: name: str coords: NWPDimensionCoordinateMap - expected_indexes: dict[str, pd.Index] | None + expected_indexes: dict[str, pd.Index] # type: ignore tests = [ TestCase( @@ -174,7 +174,7 @@ def test_from_pandas(self) -> None: @dataclasses.dataclass class TestCase: name: str - data: dict[str, pd.Index] + data: dict[str, pd.Index] # type: ignore expected_coordinates: NWPDimensionCoordinateMap | None should_error: bool diff --git a/src/nwp_consumer/internal/entities/test_parameters.py b/src/nwp_consumer/internal/entities/test_parameters.py index 4d18eb8e..1470e28f 100644 --- a/src/nwp_consumer/internal/entities/test_parameters.py +++ b/src/nwp_consumer/internal/entities/test_parameters.py @@ -10,7 +10,7 @@ class TestParameters(unittest.TestCase): """Test the business methods of the Parameters class.""" @given(st.sampled_from(Parameter)) - def test_metadata(self, p: Parameter): + def test_metadata(self, p: Parameter) -> None: """Test the metadata method.""" metadata = p.metadata() self.assertEqual(metadata.name, p.value) diff --git a/src/nwp_consumer/internal/entities/test_repometadata.py b/src/nwp_consumer/internal/entities/test_repometadata.py index b29584a2..e2287015 100644 --- a/src/nwp_consumer/internal/entities/test_repometadata.py +++ b/src/nwp_consumer/internal/entities/test_repometadata.py @@ -2,9 +2,8 @@ import datetime as dt import unittest -from . import NWPDimensionCoordinateMap -from .repometadata import ModelRepositoryMetadata from .postprocess import PostProcessOptions +from .repometadata import ModelRepositoryMetadata class TestModelRepositoryMetadata(unittest.TestCase): @@ -18,11 +17,6 @@ class TestModelRepositoryMetadata(unittest.TestCase): delay_minutes=60, required_env=["TEST"], optional_env={"TEST": "test"}, - expected_coordinates=NWPDimensionCoordinateMap( - init_time=[dt.datetime(2021, 1, 1, tzinfo=dt.UTC)], - step=[1, 2], - variable=[], - ), max_connections=1, postprocess_options=PostProcessOptions(), ) diff --git a/src/nwp_consumer/internal/ports/repositories.py b/src/nwp_consumer/internal/ports/repositories.py index 5babf27b..6bfe5983 100644 --- a/src/nwp_consumer/internal/ports/repositories.py +++ b/src/nwp_consumer/internal/ports/repositories.py @@ -1,13 +1,11 @@ """Repository interfaces for NWP data sources and stores. 
These interfaces define the signatures that *driven* actors must conform to -in order to interact with the core. These interfaces include providers of -NWP data (`ModelRepository`) and stores for processed data (`ZarrRepository`). - +in order to interact with the core. Also sometimes referred to as *secondary ports*. -All NWP providers use some kind of model_repositories to generate their data. This model_repositories -can be physically based, such as ERA5, or a machine learning model_repositories, such as +All NWP providers use some kind of model to generate their data. This model +can be physics-based, such as ERA5, or a machine learning model, such as Google's GraphCast. The `ModelRepository` interface is used to abstract the differences between these models, allowing the core to interact with them in a uniform way. @@ -17,6 +15,7 @@ import datetime as dt import pathlib from collections.abc import Callable, Iterator + import xarray as xr from returns.result import ResultE @@ -41,6 +40,13 @@ class ModelRepository(abc.ABC): - the *store*: The Zarr store containing the processed data """ + @classmethod + @abc.abstractmethod + def authenticate(cls) -> ResultE["ModelRepository"]: + """Create a new authenticated instance of the class.""" + pass + + @abc.abstractmethod def fetch_init_data(self, it: dt.datetime) \ -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]: @@ -74,7 +80,7 @@ def fetch_init_data(self, it: dt.datetime) \ ... ... def _download_and_convert(self, file: str) -> ResultE[list[xr.DataArray]]: ... '''Download and convert a raw file to an xarray dataset.''' - ... return Result.from_value([xr.open_dataset(file).to_dataarray()]) + ... return Success([xr.open_dataset(file).to_dataarray()]) .. important:: No downloading or processing should be done in this method*. All of that should be handled in the function that is yielded by the generator - @@ -104,7 +110,7 @@ def fetch_init_data(self, it: dt.datetime) \ @staticmethod @abc.abstractmethod - def metadata() -> entities.ModelRepositoryMetadata: + def repository() -> entities.ModelRepositoryMetadata: """Metadata about the model repository. See Also: @@ -112,6 +118,15 @@ def metadata() -> entities.ModelRepositoryMetadata: """ pass + @staticmethod + @abc.abstractmethod + def model() -> entities.ModelMetadata: + """Metadata about the model. + + See Also: + - `entities.ModelMetadata`. + """ + class ZarrRepository(abc.ABC): """Interface for a repository that stores Zarr NWP data.""" diff --git a/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py b/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py index 28201f51..c67ec0de 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/ecmwf_realtime.py @@ -1,21 +1,21 @@ """Model repository implementation for ECMWF live data from S3. When getting live or realtime data from ECMWF, grib files are sent by -your data provider to a location of choice, in this case an S3 bucket. 
""" import datetime as dt import logging import os import pathlib -from collections.abc import Callable, Collection, Iterator +from collections.abc import Callable, Iterator from typing import override import cfgrib import s3fs import xarray as xr from joblib import delayed -from returns.result import Failure, Result, ResultE, Success +from returns.result import Failure, ResultE, Success from nwp_consumer.internal import entities, ports @@ -36,9 +36,9 @@ def __init__(self, bucket: str, fs: s3fs.S3FileSystem) -> None: @staticmethod @override - def metadata() -> entities.ModelRepositoryMetadata: + def repository() -> entities.ModelRepositoryMetadata: return entities.ModelRepositoryMetadata( - name="ecmwf_realtime_operational_uk_11km", + name="ECMWF-Realtime-S3", is_archive=False, is_order_based=True, running_hours=[0, 12], @@ -50,7 +50,19 @@ def metadata() -> entities.ModelRepositoryMetadata: "ECMWF_REALTIME_S3_BUCKET", "ECMWF_REALTIME_S3_REGION", ], - optional_env={}, + optional_env={ + "ECMWF_REALTIME_DISSEMINATION_FILE_PREFIX": "A2", + "ECMWF_REALTIME_S3_BUCKET_PREFIX": "ecmwf", + }, + postprocess_options=entities.PostProcessOptions(), + ) + + @staticmethod + @override + def model() -> entities.ModelMetadata: + return entities.ModelMetadata( + name="HRES-IFS", + resolution="0.1 degrees", expected_coordinates=entities.NWPDimensionCoordinateMap( init_time=[], step=list(range(0, 84, 1)), @@ -71,10 +83,9 @@ def metadata() -> entities.ModelRepositoryMetadata: entities.Parameter.SNOW_DEPTH_GL, entities.Parameter.VISIBILITY_SL, ], - latitude=[float(f"{lat/10:.2f}") for lat in range(900, -900 - 1, -1)], - longitude=[float(f"{lon/10:.2f}") for lon in range(-1800, 1800 + 1, 1)], + latitude=[float(f"{lat / 10:.2f}") for lat in range(900, -900 - 1, -1)], + longitude=[float(f"{lon / 10:.2f}") for lon in range(-1800, 1800 + 1, 1)], ), - postprocess_options=entities.PostProcessOptions(), ) @override @@ -107,12 +118,8 @@ def fetch_init_data(self, it: dt.datetime) \ yield delayed(self._download_and_convert)(url=url) @classmethod + @override def authenticate(cls) -> ResultE["ECMWFRealTimeS3ModelRepository"]: - """Authenticate with the S3 bucket. - - Returns: - The authenticated S3 filesystem object. - """ try: bucket: str = os.environ["ECMWF_REALTIME_S3_BUCKET"] _fs: s3fs.S3FileSystem = s3fs.S3FileSystem( @@ -132,9 +139,13 @@ def authenticate(cls) -> ResultE["ECMWFRealTimeS3ModelRepository"]: return Success(cls(bucket=bucket, fs=_fs)) - def _download_and_convert(self, url: str) -> ResultE[Collection[xr.DataArray]]: - # TODO - pass + def _download_and_convert(self, url: str) -> ResultE[list[xr.DataArray]]: + """Download and convert a file to xarray DataArrays. + + Args: + url: The URL of the file to download. + """ + return self._download(url=url).bind(self._convert) def _download(self, url: str) -> ResultE[pathlib.Path]: """Download an ECMWF realtime file from S3. @@ -150,7 +161,10 @@ def _download(self, url: str) -> ResultE[pathlib.Path]: local_path: pathlib.Path = ( pathlib.Path( - os.getenv("RAWDIR", f"~/.local/cache/nwp/{self.metadata().name}/raw"), + os.getenv( + "RAWDIR", + f"~/.local/cache/nwp/{self.repository().name}/{self.model().name}/raw", + ), ) / url.split("/")[-1] ) @@ -174,7 +188,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]: f"Failed to download file from S3 at '{url}'. 
Encountered error: {e}", )) - return Result.from_value(local_path) + return Success(local_path) @staticmethod def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: @@ -186,10 +200,68 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: try: dss: list[xr.Dataset] = cfgrib.open_datasets(path.as_posix()) except Exception as e: - return Result.from_failure(OSError(f"Error opening '{path}' as xarray Dataset: {e}")) - # TODO: Rename the variables to match the expected names - pass - + return Failure(OSError( + f"Error opening '{path}' as list of xarray Datasets: {e}", + )) + if len(dss) == 0: + return Failure(ValueError(f"No datasets found in '{path}'")) + processed_das: list[xr.DataArray] = [] + for i, ds in enumerate(dss): + try: + da = xr.DataArray( + ds.drop_vars( + names=[ + v for v in ds.coords + if v not in ["time", "step", "latitude", "longitude"] + ], + errors="ignore", + + ) + .rename({"time": "init_time"}) + .expand_dims("init_time") + .expand_dims("step") + .pipe(ECMWFRealTimeS3ModelRepository._rename_vars) + .to_dataarray(name=ECMWFRealTimeS3ModelRepository.repository().name) + .transpose("init_time", "step", "latitude", "longitude") + .sortby("step"), + ) + except Exception as e: + return Failure(ValueError( + f"Error processing dataset {i} from '{path}' to DataArray: {e}", + )) + processed_das.append(da) + del ds + return Success(processed_das) + @staticmethod + def _rename_vars(ds: xr.Dataset) -> xr.Dataset: + """Rename variables to match the expected names.""" + rename_map: dict[str, str] = { + "dsrp": entities.Parameter.DIRECT_SHORTWAVE_RADIATION_FLUX_GL.value, + "uvb": entities.Parameter.DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL.value, + "sd": entities.Parameter.SNOW_DEPTH_GL.value, + "tcc": entities.Parameter.CLOUD_COVER_TOTAL.value, + "clt": entities.Parameter.CLOUD_COVER_TOTAL.value, + "u10": entities.Parameter.WIND_U_COMPONENT_10m.value, + "v10": entities.Parameter.WIND_V_COMPONENT_10m.value, + "t2m": entities.Parameter.TEMPERATURE_SL.value, + "ssrd": entities.Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.value, + "strd": entities.Parameter.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL.value, + "lcc": entities.Parameter.CLOUD_COVER_LOW.value, + "mcc": entities.Parameter.CLOUD_COVER_MEDIUM.value, + "hcc": entities.Parameter.CLOUD_COVER_HIGH.value, + "vis": entities.Parameter.VISIBILITY_SL.value, + "u200": entities.Parameter.WIND_U_COMPONENT_200m.value, + "v200": entities.Parameter.WIND_V_COMPONENT_200m.value, + "u100": entities.Parameter.WIND_U_COMPONENT_100m.value, + "v100": entities.Parameter.WIND_V_COMPONENT_100m.value, + "tprate": entities.Parameter.TOTAL_PRECIPITATION_RATE_GL.value, + } + + for old, new in rename_map.items(): + if old in ds.data_vars: + ds = ds.rename({old: new}) + return ds diff --git a/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py b/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py index ffabdf93..8b38eabc 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/metoffice_global.py @@ -10,7 +10,8 @@ and the spec sheet from the Met Office is detailed in `this PDF `_. -For further details on the repository, see the `CedaMetOfficeGlobalModelRepository.metadata` implementation. +For further details on the repository, see the +`CedaMetOfficeGlobalModelRepository.repository` implementation. 
Data discrepancies and corrections ================================== @@ -108,9 +109,9 @@ def __init__(self, url_auth: str) -> None: @staticmethod @override - def metadata() -> entities.ModelRepositoryMetadata: + def repository() -> entities.ModelRepositoryMetadata: return entities.ModelRepositoryMetadata( - name="ceda_metoffice_global_17km", + name="CEDA", is_archive=True, is_order_based=False, running_hours=[0, 12], # 6 and 18 exist, but are lacking variables @@ -118,7 +119,18 @@ def metadata() -> entities.ModelRepositoryMetadata: max_connections=20, required_env=["CEDA_FTP_USER", "CEDA_FTP_PASS"], optional_env={}, - expected_coordinates=entities.NWPDimensionCoordinateMap( + postprocess_options=entities.PostProcessOptions( + standardize_coordinates=True, + ), + ) + + @staticmethod + @override + def model() -> entities.ModelMetadata: + return entities.ModelMetadata( + name="UM-Global", + resolution="17km", + expected_coordinates = entities.NWPDimensionCoordinateMap( init_time=[], step=list(range(0, 48, 1)), variable=[ @@ -146,9 +158,6 @@ def metadata() -> entities.ModelRepositoryMetadata: ]) ], ), - postprocess_options=entities.PostProcessOptions( - standardize_coordinates=True, - ), ) @override @@ -185,28 +194,25 @@ def fetch_init_data(self, it: dt.datetime) \ pass - def _download_and_convert(self, url: str) \ - -> ResultE[list[xr.DataArray]]: - """Download and convert a file to an xarray dataset. + def _download_and_convert(self, url: str) -> ResultE[list[xr.DataArray]]: + """Download and convert a file to xarray DataArrays. Args: url: The URL of the file to download. - - Returns: - A ResultE containing the xarray dataset. """ return self._download(url).bind(self._convert) @classmethod + @override def authenticate(cls) -> ResultE["CedaMetOfficeGlobalModelRepository"]: """Authenticate with the CEDA FTP server. Returns: A Result containing the instantiated class if successful, or an error if not. """ - if all(k not in os.environ for k in cls.metadata().required_env): + if all(k not in os.environ for k in cls.repository().required_env): return Failure(ValueError( - f"Missing required environment variables: {cls.metadata().required_env}", + f"Missing required environment variables: {cls.repository().required_env}", )) username: str = urllib.parse.quote(os.environ["CEDA_FTP_USER"]) password: str = urllib.parse.quote(os.environ["CEDA_FTP_PASS"]) @@ -219,17 +225,12 @@ def _download(self, url: str) -> ResultE[pathlib.Path]: Args: url: The URL of the file to download. """ - if self._url_auth is None: - return Result.from_failure( - ValueError( - "Not authenticated with CEDA FTP server. 
" - "Ensure the 'authenticate' method has been called.", - ), - ) - local_path: pathlib.Path = ( pathlib.Path( - os.getenv("RAWDIR", f"~/.local/cache/nwp/{self.metadata().name}/raw"), + os.getenv( + "RAWDIR", + f"~/.local/cache/nwp/{self.repository().name}/{self.model().name}/raw", + ), ) / url.split("/")[-1] ) @@ -242,7 +243,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]: timeout=30, ) except Exception as e: - return Result.from_failure(OSError(f"Error fetching {url}: {e}")) + return Failure(OSError(f"Error fetching {url}: {e}")) local_path.parent.mkdir(parents=True, exist_ok=True) log.debug("Downloading %s to %s", url, local_path) @@ -258,7 +259,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]: local_path.stat().st_size, ) except Exception as e: - return Result.from_failure( + return Failure( OSError( f"Error saving '{url}' to '{local_path}': {e}", ), @@ -293,7 +294,7 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]: ], ) .pipe(CedaMetOfficeGlobalModelRepository._rename_vars) - .to_dataarray(name=CedaMetOfficeGlobalModelRepository.metadata().name) + .to_dataarray(name=CedaMetOfficeGlobalModelRepository.model().name) .transpose("init_time", "step", "variable", "latitude", "longitude") # Remove the last value of the longitude dimension as it overlaps with the next file # Reverse the latitude dimension to be in descending order diff --git a/src/nwp_consumer/internal/repositories/model_repositories/test_metoffice_global.py b/src/nwp_consumer/internal/repositories/model_repositories/test_metoffice_global.py index 3d8007f6..2104b5ac 100644 --- a/src/nwp_consumer/internal/repositories/model_repositories/test_metoffice_global.py +++ b/src/nwp_consumer/internal/repositories/model_repositories/test_metoffice_global.py @@ -27,7 +27,7 @@ def test__download_and_convert(self) -> None: test_it: dt.datetime = dt.datetime(2021, 1, 1, 0, tzinfo=dt.UTC) test_coordinates: entities.NWPDimensionCoordinateMap = dataclasses.replace( - c.metadata().expected_coordinates, + c.model().expected_coordinates, init_time=[test_it], ) diff --git a/src/nwp_consumer/internal/repositories/notification_repositories/dagster.py b/src/nwp_consumer/internal/repositories/notification_repositories/dagster.py index 07bb2c85..0998f1f8 100644 --- a/src/nwp_consumer/internal/repositories/notification_repositories/dagster.py +++ b/src/nwp_consumer/internal/repositories/notification_repositories/dagster.py @@ -1,6 +1,6 @@ """Dagster pipes notification repository implementation. -`Dagster Pipes `_ +`Dagster Pipes `_ enables integration with Dagster for reporting asset materialization and logging. This module enables dagster instances running this code to recieve notifications. 
@@ -13,7 +13,7 @@ from typing import override from dagster_pipes import PipesContext, open_dagster_pipes -from returns.result import Result, ResultE +from returns.result import ResultE, Success from nwp_consumer.internal import entities, ports @@ -41,5 +41,5 @@ def notify( }, }, ) - return Result.from_value("Notification sent to dagster successfully.") + return Success("Notification sent to dagster successfully.") diff --git a/src/nwp_consumer/internal/repositories/notification_repositories/stdout.py b/src/nwp_consumer/internal/repositories/notification_repositories/stdout.py index eedae5ee..4437ba00 100644 --- a/src/nwp_consumer/internal/repositories/notification_repositories/stdout.py +++ b/src/nwp_consumer/internal/repositories/notification_repositories/stdout.py @@ -3,7 +3,7 @@ import logging from typing import override -from returns.result import Result, ResultE +from returns.result import ResultE, Success from nwp_consumer.internal import entities, ports @@ -19,5 +19,5 @@ def notify( message: entities.StoreCreatedNotification | entities.StoreAppendedNotification, ) -> ResultE[str]: log.info(f"{message}") - return Result.from_value("Notification sent to stdout successfully.") + return Success("Notification sent to stdout successfully.") diff --git a/src/nwp_consumer/internal/services/archiver_service.py b/src/nwp_consumer/internal/services/archiver_service.py index c88f03b6..160e2517 100644 --- a/src/nwp_consumer/internal/services/archiver_service.py +++ b/src/nwp_consumer/internal/services/archiver_service.py @@ -1,16 +1,18 @@ """Implementation of the NWP consumer services.""" import dataclasses -import datetime as dt import logging import pathlib -from typing import override +from typing import TYPE_CHECKING, override from joblib import Parallel -from returns.result import Failure, Result, ResultE, Success +from returns.result import Failure, ResultE, Success from nwp_consumer.internal import entities, ports +if TYPE_CHECKING: + import datetime as dt + log = logging.getLogger("nwp-consumer") @@ -22,60 +24,72 @@ class ArchiverService(ports.ArchiveUseCase): and writing it to a Zarr store. 
""" - _mr: ports.ModelRepository - _nr: ports.NotificationRepository + mr: type[ports.ModelRepository] + nr: type[ports.NotificationRepository] def __init__( self, - model_repository: ports.ModelRepository, - notification_repository: ports.NotificationRepository, + model_repository: type[ports.ModelRepository], + notification_repository: type[ports.NotificationRepository], ) -> None: """Create a new instance.""" - self._mr = model_repository - self._nr = notification_repository + self.mr = model_repository + self.nr = notification_repository @override def archive(self, year: int, month: int) -> ResultE[pathlib.Path]: monitor = entities.PerformanceMonitor() - init_times = self._mr.metadata().month_its(year=year, month=month) + init_times = self.mr.repository().month_its(year=year, month=month) # Create a store for the archive - init_store_result: ResultE[entities.TensorStore] = entities.TensorStore.initialize_empty_store( - name=self._mr.metadata().name, - coords=dataclasses.replace( - self._mr.metadata().expected_coordinates, - init_time=init_times, - ), - overwrite_existing=False, - ) + init_store_result: ResultE[entities.TensorStore] = \ + entities.TensorStore.initialize_empty_store( + name=self.mr.repository().name, + coords=dataclasses.replace( + self.mr.model().expected_coordinates, + init_time=init_times, + ), + overwrite_existing=False, + ) match init_store_result: case Failure(e): monitor.join() # TODO: Make this a context manager instead - return Result.from_failure(OSError( + return Failure(OSError( f"Failed to initialize store for {year}-{month}: {e}"), ) case Success(store): missing_times_result = store.missing_times() if isinstance(missing_times_result, Failure): monitor.join() - return Result.from_failure(missing_times_result.failure()) + return Failure(missing_times_result.failure()) log.info(f"{len(missing_times_result.unwrap())} missing init_times in store.") failed_times: list[dt.datetime] = [] for n, it in enumerate(missing_times_result.unwrap()): log.info( - f"Consuming data from {self._mr.metadata().name} for {it:%Y-%m-%d %H:%M} " + f"Consuming data from {self.mr.repository().name} for {it:%Y-%m-%d %H:%M} " f"(time {n + 1}/{len(missing_times_result.unwrap())})", ) + # Authenticate with the model repository + amr_result = self.mr.authenticate() + if isinstance(amr_result, Failure): + monitor.join() + return Failure(OSError( + "Unable to authenticate with model repository " + f"'{self.mr.repository().name}': " + f"{amr_result.failure()}", + )) + amr = amr_result.unwrap() + # Create a generator to fetch and process raw data da_result_generator = Parallel( - n_jobs=self._mr.metadata().max_connections - 1, + n_jobs=self.mr.repository().max_connections - 1, prefer="threads", return_as="generator_unordered", - )(self._mr.fetch_init_data(it=it)) + )(amr.fetch_init_data(it=it)) # Regionally write the results of the generator as they are ready for da_result in da_result_generator: @@ -96,11 +110,11 @@ def archive(self, year: int, month: int) -> ResultE[pathlib.Path]: # postprocess_result = store.postprocess(self._mr.metadata().postprocess_options) # if isinstance(postprocess_result, Failure): # monitor.join() # TODO: Make this a context manager instead - # return Result.from_failure(postprocess_result.failure()) + # return Failure(postprocess_result.failure()) monitor.join() - notify_result = self._nr.notify( - entities.StoreCreatedNotification( + notify_result = self.nr().notify( + message=entities.StoreCreatedNotification( filename=store.path.name, size_mb=store.size_mb, 
performance=entities.PerformanceMetadata( @@ -110,12 +124,14 @@ def archive(self, year: int, month: int) -> ResultE[pathlib.Path]: ), ) if isinstance(notify_result, Failure): - log.error("Failed to notify of store creation") - return notify_result + return Failure(OSError( + "Failed to notify of store creation: " + f"{notify_result.failure()}", + )) - return Result.from_value(store.path) + return Success(store.path) case _: - return Result.from_failure( + return Failure( TypeError(f"Unexpected result type: {type(init_store_result)}"), ) diff --git a/src/nwp_consumer/internal/services/consumer_service.py b/src/nwp_consumer/internal/services/consumer_service.py index 8ac5f94a..41e48f9b 100644 --- a/src/nwp_consumer/internal/services/consumer_service.py +++ b/src/nwp_consumer/internal/services/consumer_service.py @@ -7,7 +7,7 @@ from typing import override from joblib import Parallel -from returns.result import Failure, Result, ResultE, Success +from returns.result import Failure, ResultE, Success from nwp_consumer.internal import entities, ports @@ -22,35 +22,33 @@ class ConsumerService(ports.ConsumeUseCase): and writing it to a Zarr store. """ - _mr: ports.ModelRepository - _zr: ports.ZarrRepository - _nr: ports.NotificationRepository + mr: type[ports.ModelRepository] + nr: type[ports.NotificationRepository] def __init__( self, # TODO: 2024-10-21 - Work out how to pass none instantiated class values through DI - model_repository: ports.ModelRepository, - zarr_repository: ports.ZarrRepository, - notification_repository: ports.NotificationRepository, + model_repository: type[ports.ModelRepository], + notification_repository: type[ports.NotificationRepository], ) -> None: """Create a new instance.""" - self._mr = model_repository - self._zr = zarr_repository - self._nr = notification_repository + self.mr = model_repository + self.nr = notification_repository @override def consume(self, it: dt.datetime | None = None) -> ResultE[pathlib.Path]: monitor = entities.PerformanceMonitor() if it is None: - it = self._mr.metadata().determine_latest_it_from(dt.datetime.now(tz=dt.UTC)) - log.info(f"Consuming data from {self._mr.metadata().name} for {it:%Y-%m-%d %H:%M}") + it = self.mr.repository().determine_latest_it_from(dt.datetime.now(tz=dt.UTC)) + log.info(f"Consuming data from {self.mr.repository().name} for {it:%Y-%m-%d %H:%M}") # Create a store for the init time - init_store_result: ResultE[entities.TensorStore] = entities.TensorStore.initialize_empty_store( - name=self._mr.metadata().name, - coords=dataclasses.replace(self._mr.metadata().expected_coordinates, init_time=[it]), - ) + init_store_result: ResultE[entities.TensorStore] = \ + entities.TensorStore.initialize_empty_store( + name=self.mr.model().name, + coords=dataclasses.replace(self.mr.model().expected_coordinates, init_time=[it]), + ) match init_store_result: case Failure(e): @@ -59,11 +57,21 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[pathlib.Path]: case Success(store): # Create a generator to fetch and process raw data + amr_result = self.mr.authenticate() + if isinstance(amr_result, Failure): + monitor.join() + return Failure(OSError( + "Unable to authenticate with model repository " + f"'{self.mr.repository().name}': " + f"{amr_result.failure()}", + )) + amr = amr_result.unwrap() + fetch_result_generator = Parallel( - n_jobs=self._mr.metadata().max_connections - 1, + n_jobs=self.mr.repository().max_connections - 1, prefer="threads", return_as="generator_unordered", - )(self._mr.fetch_init_data(it=it)) + 
)(amr.fetch_init_data(it=it)) # Regionally write the results of the generator as they are ready for fetch_result in fetch_result_generator: @@ -71,7 +79,7 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[pathlib.Path]: monitor.join() return Failure(OSError( f"Error fetching data for init time '{it:%Y-%m-%d %H:%M}' ", - f"and model {self._mr.metadata().name}: {fetch_result.failure()}", + f"and model {self.mr.repository().name}: {fetch_result.failure()}", )) for da in fetch_result.unwrap(): write_result = store.write_to_region(da) @@ -81,20 +89,20 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[pathlib.Path]: monitor.join() # TODO: Make this a context manager instead return Failure(OSError( f"Error writing data for init time '{it:%Y-%m-%d %H:%M}' ", - f"and model {self._mr.metadata().name}: {write_result.failure()}", + f"and model {self.mr.repository().name}: {write_result.failure()}", )) del fetch_result_generator # Postprocess the dataset as required - postprocess_result = store.postprocess(self._mr.metadata().postprocess_options) + postprocess_result = store.postprocess(self.mr.repository().postprocess_options) if isinstance(postprocess_result, Failure): monitor.join() # TODO: Make this a context manager instead return Failure(postprocess_result.failure()) monitor.join() - notify_result = self._nr.notify( - entities.StoreCreatedNotification( + notify_result = self.nr().notify( + message=entities.StoreCreatedNotification( filename=store.path.name, size_mb=store.size_mb, performance=entities.PerformanceMetadata( @@ -104,8 +112,10 @@ def consume(self, it: dt.datetime | None = None) -> ResultE[pathlib.Path]: ), ) if isinstance(notify_result, Failure): - log.error("Failed to notify of store creation") - return notify_result + return Failure(OSError( + "Failed to notify of store creation: " + f"{notify_result.failure()}", + )) return Success(store.path) diff --git a/src/nwp_consumer/internal/services/test_consumer.py b/src/nwp_consumer/internal/services/test_consumer.py index fa2fece4..6102c475 100644 --- a/src/nwp_consumer/internal/services/test_consumer.py +++ b/src/nwp_consumer/internal/services/test_consumer.py @@ -8,7 +8,7 @@ import xarray as xr from joblib import delayed from returns.pipeline import is_successful -from returns.result import Result, ResultE +from returns.result import ResultE, Success from nwp_consumer.internal import entities, ports from nwp_consumer.internal.services.consumer_service import ConsumerService @@ -16,12 +16,16 @@ class DummyModelRepository(ports.ModelRepository): + @classmethod + @override + def authenticate(cls) -> ResultE["DummyModelRepository"]: + return Success(cls()) + @staticmethod @override - def metadata() -> entities.ModelRepositoryMetadata: - """See parent class.""" + def repository() -> entities.ModelRepositoryMetadata: return entities.ModelRepositoryMetadata( - name="dummy", + name="ACME-Test-Models", is_archive=False, is_order_based=False, running_hours=[0, 6, 12, 18], @@ -30,6 +34,14 @@ def metadata() -> entities.ModelRepositoryMetadata: required_env=[], optional_env={}, postprocess_options=entities.PostProcessOptions(), + ) + + @staticmethod + @override + def model() -> entities.ModelMetadata: + return entities.ModelMetadata( + name="simple-random", + resolution="17km", expected_coordinates=entities.NWPDimensionCoordinateMap( init_time=[dt.datetime(2021, 1, 1, 0, 0, tzinfo=dt.UTC)], step=list(range(0, 48, 1)), @@ -43,28 +55,28 @@ def metadata() -> entities.ModelRepositoryMetadata: ), ) + @override def 
fetch_init_data(self, it: dt.datetime) \ -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]: - """See parent class.""" def gen_dataset(step: int, variable: str) -> ResultE[list[xr.DataArray]]: """Define a generator that provides one variable at one step.""" da = xr.DataArray( - name=self.metadata().name, + name=self.repository().name, dims=["init_time", "step", "variable", "latitude", "longitude"], data=np.random.rand(1, 1, 1, 721, 1440), - coords=self.metadata().expected_coordinates.to_pandas() | { + coords=self.model().expected_coordinates.to_pandas() | { "init_time": [np.datetime64(it.replace(tzinfo=None), "ns")], "step": [step], "variable": [variable], }, ) - return Result.from_value([da]) + return Success([da]) - for s in self.metadata().expected_coordinates.step: - for v in self.metadata().expected_coordinates.variable: + for s in self.model().expected_coordinates.step: + for v in self.model().expected_coordinates.variable: yield delayed(gen_dataset)(s, v.value) @@ -76,8 +88,7 @@ def notify( message: entities.StoreAppendedNotification | entities.StoreCreatedNotification, ) -> ResultE[str]: """See parent class.""" - print(message) - return Result.from_value(str(message)) + return Success(str(message)) class DummyZarrRepository(ports.ZarrRepository): @@ -85,7 +96,7 @@ class DummyZarrRepository(ports.ZarrRepository): @override def save(self, src: pathlib.Path, dst: pathlib.Path) -> ResultE[str]: """See parent class.""" - return Result.from_value(str(dst)) + return Success(str(dst)) class TestParallelConsumer(unittest.TestCase): @@ -94,9 +105,8 @@ def test_consume(self) -> None: """Test the consume method of the ParallelConsumer class.""" test_consumer = ConsumerService( - model_repository=DummyModelRepository(), - notification_repository=DummyNotificationRepository(), - zarr_repository=DummyZarrRepository(), + model_repository=DummyModelRepository, + notification_repository=DummyNotificationRepository, ) result = test_consumer.consume(it=dt.datetime(2021, 1, 1, tzinfo=dt.UTC)) diff --git a/src/test_integration/test_integration.py b/src/test_integration/test_integration.py index 340fb31e..fe1854a2 100644 --- a/src/test_integration/test_integration.py +++ b/src/test_integration/test_integration.py @@ -2,24 +2,27 @@ import unittest import xarray as xr -from nwp_consumer.internal import handlers, repositories, services from returns.pipeline import is_successful +from nwp_consumer.internal import handlers, repositories, services + class TestIntegration(unittest.TestCase): def test_ceda_metoffice_global_model(self) -> None: c = handlers.CLIHandler( consumer_usecase=services.ConsumerService( - model_repository=repositories.CedaMetOfficeGlobalModelRepository(), - notification_repository=repositories.StdoutNotificationRepository(), - zarr_repository=None, + model_repository=repositories.CedaMetOfficeGlobalModelRepository, + notification_repository=repositories.StdoutNotificationRepository, + ), + archiver_usecase=services.ArchiverService( + model_repository=repositories.CedaMetOfficeGlobalModelRepository, + notification_repository=repositories.StdoutNotificationRepository, ), - archiver_usecase=None, ) result = c._consumer_usecase.consume(it=dt.datetime(2021, 1, 1, tzinfo=dt.UTC)) self.assertTrue(is_successful(result), msg=f"{result}") da = xr.open_dataarray(result.unwrap(), engine="zarr") - print(da) + self.assertTrue(da.sizes["init_time"] > 0)
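
# Usage note (not part of the patch): a minimal sketch of the class-based dependency
# injection this diff introduces, assuming the package layout shown above. Repository
# classes are passed uninstantiated; the service calls `authenticate()` internally to
# obtain a usable instance before fetching data. The init time below is illustrative.
import datetime as dt

from returns.pipeline import is_successful

from nwp_consumer.internal import repositories, services

# Wire the consumer with repository *classes*, mirroring the integration test above.
service = services.ConsumerService(
    model_repository=repositories.CedaMetOfficeGlobalModelRepository,
    notification_repository=repositories.StdoutNotificationRepository,
)

result = service.consume(it=dt.datetime(2024, 1, 1, tzinfo=dt.UTC))
if is_successful(result):
    print(f"Store written to {result.unwrap()}")
else:
    print(f"Consumption failed: {result.failure()}")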