diff --git a/.gitignore b/.gitignore index a9b169b..4d99463 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,3 @@ -Untitled.ipynb -/package-lock.json -/node_modules -.vscode -.idea - ### ArchLinuxPackages ### *.tar *.tar.* @@ -156,6 +150,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +**/.benchmarks cover/ # Translations @@ -179,6 +174,7 @@ target/ # Jupyter Notebook .ipynb_checkpoints +Untitled.ipynb # IPython profile_default/ @@ -261,7 +257,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ ### Git ### # Created by git for backups. To disable backups in Git: @@ -578,12 +574,7 @@ tags [._]*.un~ ### VisualStudioCode ### -.vscode/* -!.vscode/settings.json -!.vscode/tasks.json -!.vscode/launch.json -!.vscode/extensions.json -!.vscode/*.code-snippets +.vscode/ # Local History for Visual Studio Code .history/ @@ -936,6 +927,7 @@ FakesAssemblies/ # Node.js Tools for Visual Studio .ntvs_analysis.dat node_modules/ +/package-lock.json # Visual Studio 6 build log *.plg @@ -1036,5 +1028,3 @@ FodyWeavers.xsd ### VisualStudio Patch ### # Additional files built by Visual Studio - -# End of https://www.toptal.com/developers/gitignore/api/linux,archlinuxpackages,osx,windows,python,c,django,database,pycharm,visualstudio,visualstudiocode,vim,zsh,git,diff,microsoftoffice,spreadsheet,ssh,certificates diff --git a/CHANGELOG.md b/CHANGELOG.md index 615be94..d4d3eca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add the missing JSON schema `item_assets` definition under a Collection to ensure compatibility with the [Item Assets](https://github.com/stac-extensions/item-assets) extension, as mentioned this specification. +- Add `ModelBand` representation using `name`, `format` and `expression` properties to allow derived band references + (fixes [crim-ca/mlm-extension#7](https://github.com/crim-ca/mlm-extension/discussions/7)). ### Changed - Adds a job to publish.yaml to publish the stac-model package diff --git a/Makefile b/Makefile index 3093616..6c0fa19 100644 --- a/Makefile +++ b/Makefile @@ -81,6 +81,10 @@ lint: .PHONY: check-lint check-lint: lint +.PHONY: format-lint +format-lint: + poetry run ruff --config=pyproject.toml --fix ./ + .PHONY: install-npm install-npm: npm install @@ -101,7 +105,8 @@ check-examples: install-npm format-examples: install-npm npm run format-examples -fix-%: format-%s +FORMATTERS := lint markdown examples +$(addprefix fix-, $(FORMATTERS)): fix-%: format-% .PHONY: lint-all lint-all: lint mypy check-safety check-markdown diff --git a/README.md b/README.md index 6ee4fc8..a80a2a0 100644 --- a/README.md +++ b/README.md @@ -209,18 +209,18 @@ set to `true`, there would be no `accelerator` to contain against. To avoid conf ### Model Input Object -| Field Name | Type | Description | -|-------------------------|---------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. | -| bands | \[string] | **REQUIRED** The names of the raster bands used to train or fine-tune the model, which may be all or a subset of bands available in a STAC Item's [Band Object](#bands-and-statistics). If no band applies for one input, use an empty array. | -| input | [Input Structure Object](#input-structure-object) | **REQUIRED** The N-dimensional array definition that describes the shape, dimension ordering, and data type. | -| description | string | Additional details about the input such as describing its purpose or expected source that cannot be represented by other properties. | -| norm_by_channel | boolean | Whether to normalize each channel by channel-wise statistics or to normalize by dataset statistics. If True, use an array of `statistics` of same dimensionality and order as the `bands` field in this object. | -| norm_type | [Normalize Enum](#normalize-enum) \| null | Normalization method. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. | -| norm_clip | \[number] | When `norm_type = "clip"`, this array supplies the value for each `bands` item, which is used to divide each band before clipping values between 0 and 1. | -| resize_type | [Resize Enum](#resize-enum) \| null | High-level descriptor of the rescaling method to change image shape. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. | -| statistics | \[[Statistics Object](#bands-and-statistics)] | Dataset statistics for the training dataset used to normalize the inputs. | -| pre_processing_function | [Processing Expression](#processing-expression) \| null | Custom preprocessing function where normalization and rescaling, and any other significant operations takes place. | +| Field Name | Type | Description | +|-------------------------|---------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. | +| bands | \[string \| [Model Band Object](#model-band-object)] | **REQUIRED** The raster band references used to train, fine-tune or perform inference with the model, which may be all or a subset of bands available in a STAC Item's [Band Object](#bands-and-statistics). If no band applies for one input, use an empty array. | +| input | [Input Structure Object](#input-structure-object) | **REQUIRED** The N-dimensional array definition that describes the shape, dimension ordering, and data type. | +| description | string | Additional details about the input such as describing its purpose or expected source that cannot be represented by other properties. | +| norm_by_channel | boolean | Whether to normalize each channel by channel-wise statistics or to normalize by dataset statistics. If True, use an array of `statistics` of same dimensionality and order as the `bands` field in this object. | +| norm_type | [Normalize Enum](#normalize-enum) \| null | Normalization method. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. | +| norm_clip | \[number] | When `norm_type = "clip"`, this array supplies the value for each `bands` item, which is used to divide each band before clipping values between 0 and 1. | +| resize_type | [Resize Enum](#resize-enum) \| null | High-level descriptor of the rescaling method to change image shape. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. | +| statistics | \[[Statistics Object](#bands-and-statistics)] | Dataset statistics for the training dataset used to normalize the inputs. | +| pre_processing_function | [Processing Expression](#processing-expression) \| null | Custom preprocessing function where normalization and rescaling, and any other significant operations takes place. The `pre_processing_function` should be applied over all available `bands`. For respective band operations, see [Model Band Object](#model-band-object). | Fields that accept the `null` value can be considered `null` when omitted entirely for parsing purposes. However, setting `null` explicitly when this information is known by the model provider can help users understand @@ -253,6 +253,9 @@ and [Common Band Names][stac-band-names]. Only bands used as input to the model should be included in the MLM `bands` field. To avoid duplicating the information, MLM only uses the `name` of whichever "Band Object" is defined in the STAC Item. +An input's `bands` definition can either be a plain `string` or a [Model Band Object](#model-band-object). +When a `string` is employed directly, the value should be implicitly mapped to the `name` property of the +explicit object representation. One distinction from the [STAC 1.1 - Band Object][stac-1.1-band] in MLM is that [Statistics][stac-1.1-stats] object (or the corresponding [STAC Raster - Statistics][stac-raster-stats] for STAC 1.0) are not @@ -269,6 +272,29 @@ properties of the model. [stac-raster-stats]: https://github.com/stac-extensions/raster?tab=readme-ov-file#statistics-object [stac-band-names]: https://github.com/stac-extensions/eo?tab=readme-ov-file#common-band-names +#### Model Band Object + +| Field Name | Type | Description | +|------------|--------|----------------------------------------------------------------------------------------------------------------------------------------| +| name | string | **REQUIRED** Name of the band referring to an extended band definition (see [Bands](#bands-and-statistics). | +| format | string | The type of expression that is specified in the `expression` property. | +| expression | \* | An expression compliant with the `format` specified. The expression can be applied to any data type and depends on the `format` given. | + +> :information_source:
+> Although `format` and `expression` are not required in this context, they are mutually dependent on each other.
+> See also [Processing Expression](#processing-expression) for more details and examples. + +The `format` and `expression` properties can serve multiple purpose. + +1. Applying a band-specific pre-processing step, + in contrast to [`pre_processing_function`](#model-input-object) applied over all bands. + For example, reshaping a band to align its dimensions with other bands before stacking them. + +2. Defining a derived-band operation or a calculation that produces a virtual band from other band references. + For example, computing an indice that applies an arithmetic combination of other bands. + +For a concrete example, see [examples/item_bands_expression.json](examples/item_bands_expression.json). + #### Data Type Enum When describing the `data_type` provided by a [Band](#bands-and-statistics), whether for defining diff --git a/examples/collection.json b/examples/collection.json index 46c78ff..aff632b 100644 --- a/examples/collection.json +++ b/examples/collection.json @@ -52,6 +52,10 @@ "href": "item_basic.json", "rel": "item" }, + { + "href": "item_bands_expression.json", + "rel": "item" + }, { "href": "item_eo_bands.json", "rel": "item" diff --git a/examples/item_bands_expression.json b/examples/item_bands_expression.json new file mode 100644 index 0000000..3fdd4aa --- /dev/null +++ b/examples/item_bands_expression.json @@ -0,0 +1,204 @@ +{ + "$comment": "Demonstrate the use of MLM and EO for bands description, with EO bands directly in the Model Asset.", + "stac_version": "1.0.0", + "stac_extensions": [ + "https://crim-ca.github.io/mlm-extension/v1.1.0/schema.json", + "https://stac-extensions.github.io/eo/v1.1.0/schema.json", + "https://stac-extensions.github.io/raster/v1.1.0/schema.json", + "https://stac-extensions.github.io/file/v1.0.0/schema.json", + "https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json" + ], + "type": "Feature", + "id": "resnet-18_sentinel-2_all_moco_classification", + "collection": "ml-model-examples", + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + -7.882190080512502, + 37.13739173208318 + ], + [ + -7.882190080512502, + 58.21798141355221 + ], + [ + 27.911651652899923, + 58.21798141355221 + ], + [ + 27.911651652899923, + 37.13739173208318 + ], + [ + -7.882190080512502, + 37.13739173208318 + ] + ] + ] + }, + "bbox": [ + -7.882190080512502, + 37.13739173208318, + 27.911651652899923, + 58.21798141355221 + ], + "properties": { + "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO", + "datetime": null, + "start_datetime": "1900-01-01T00:00:00Z", + "end_datetime": "9999-12-31T23:59:59Z", + "mlm:name": "Resnet-18 Sentinel-2 ALL MOCO", + "mlm:tasks": [ + "classification" + ], + "mlm:architecture": "ResNet", + "mlm:framework": "pytorch", + "mlm:framework_version": "2.1.2+cu121", + "file:size": 43000000, + "mlm:memory_size": 1, + "mlm:total_parameters": 11700000, + "mlm:pretrained_source": "EuroSat Sentinel-2", + "mlm:accelerator": "cuda", + "mlm:accelerator_constrained": false, + "mlm:accelerator_summary": "Unknown", + "mlm:batch_size_suggestion": 256, + "mlm:input": [ + { + "name": "RBG+NDVI Bands Sentinel-2 Batch", + "bands": [ + { + "name": "B04" + }, + { + "name": "B03" + }, + { + "name": "B02" + }, + { + "name": "NDVI", + "format": "rio-calc", + "expression": "(B08 - B04) / (B08 + B04)" + } + ], + "input": { + "shape": [ + -1, + 13, + 64, + 64 + ], + "dim_order": [ + "batch", + "channel", + "height", + "width" + ], + "data_type": "float32" + } + } + ], + "mlm:output": [ + { + "name": "classification", + "tasks": [ + "segmentation", + "semantic-segmentation" + ], + "result": { + "shape": [ + -1, + 10 + ], + "dim_order": [ + "batch", + "class" + ], + "data_type": "float32" + }, + "classification_classes": [ + { + "value": 1, + "name": "vegetation", + "title": "Vegetation", + "description": "Pixels were vegetation is detected.", + "color_hint": "00FF00", + "nodata": false + }, + { + "value": 0, + "name": "background", + "title": "Non-Vegetation", + "description": "Anything that is not classified as vegetation.", + "color_hint": "000000", + "nodata": false + } + ], + "post_processing_function": null + } + ] + }, + "assets": { + "weights": { + "href": "https://example.com/model-rgb-ndvi.pth", + "title": "Pytorch weights checkpoint", + "description": "A vegetation classification model trained on Sentinel-2 imagery and NDVI.", + "type": "application/octet-stream; application=pytorch", + "roles": [ + "mlm:model", + "mlm:weights" + ], + "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.", + "eo:bands": [ + { + "name": "B02", + "common_name": "blue", + "description": "Blue (band 2)", + "center_wavelength": 0.49, + "full_width_half_max": 0.098 + }, + { + "name": "B03", + "common_name": "green", + "description": "Green (band 3)", + "center_wavelength": 0.56, + "full_width_half_max": 0.045 + }, + { + "name": "B04", + "common_name": "red", + "description": "Red (band 4)", + "center_wavelength": 0.665, + "full_width_half_max": 0.038 + }, + { + "name": "B08", + "common_name": "nir", + "description": "NIR 1 (band 8)", + "center_wavelength": 0.842, + "full_width_half_max": 0.145 + } + ] + } + }, + "links": [ + { + "rel": "collection", + "href": "./collection.json", + "type": "application/json" + }, + { + "rel": "self", + "href": "./item_bands_expression.json", + "type": "application/geo+json" + }, + { + "rel": "derived_from", + "href": "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a", + "type": "application/json", + "ml-aoi:split": "train" + } + ] +} diff --git a/json-schema/schema.json b/json-schema/schema.json index 21cd5f2..b31c73f 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -711,13 +711,45 @@ } }, "ModelBands": { + "description": "List of bands (if any) that compose the input. Band order represents the index position of the bands.", "allOf": [ { "$comment": "No 'minItems' here to support model inputs not using any band (other data source).", "type": "array", "items": { - "type": "string", - "minLength": 1 + "oneOf": [ + { + "description": "Implied named-band with the name directly provided.", + "type": "string", + "minLength": 1 + }, + { + "description": "Explicit named-band with optional derived expression to obtain it.", + "type": "object", + "required": [ + "name" + ], + "properties": { + "name": { + "type": "string", + "minLength": 1 + }, + "format": { + "description": "Format to interpret the specified expression used to obtain the band.", + "type": "string", + "minLength": 1 + }, + "expression": { + "description": "Any representation relevant for the specified 'format'." + } + }, + "dependencies": { + "format": ["expression"], + "expression": ["format"] + }, + "additionalProperties": false + } + ] } }, { diff --git a/pyproject.toml b/pyproject.toml index 8e52e60..c5ad032 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,6 +224,7 @@ select = [ [tool.ruff.lint.isort] known-local-folder = ["tests", "conftest"] known-first-party = ["stac_model"] +extra-standard-library = ["typing_extensions"] [tool.mypy] # https://github.com/python/mypy diff --git a/stac_model/base.py b/stac_model/base.py index 4e8cc6b..c5ff7c6 100644 --- a/stac_model/base.py +++ b/stac_model/base.py @@ -116,7 +116,7 @@ class TaskEnum(str, Enum): ModelTask = Union[ModelTaskNames, TaskEnum] -class ProcessingExpression(BaseModel): +class ProcessingExpression(MLMBaseModel): # FIXME: should use 'pystac' reference, but 'processing' extension is not implemented yet! format: str expression: Any diff --git a/stac_model/examples.py b/stac_model/examples.py index c781a5b..47be1db 100644 --- a/stac_model/examples.py +++ b/stac_model/examples.py @@ -63,7 +63,10 @@ def eurosat_resnet() -> ItemMLModelExtension: 1231.58581042, ] stats = [ - MLMStatistic(mean=mean, stddev=stddev) + MLMStatistic( + mean=mean, + stddev=stddev, + ) for mean, stddev in zip(stats_mean, stats_stddev) ] model_input = ModelInput( @@ -82,7 +85,7 @@ def eurosat_resnet() -> ItemMLModelExtension: result_struct = ModelResult( shape=[-1, 10], dim_order=["batch", "class"], - data_type="float32" + data_type="float32", ) class_map = { "Annual Crop": 0, @@ -97,7 +100,10 @@ def eurosat_resnet() -> ItemMLModelExtension: "SeaLake": 9, } class_objects = [ - MLMClassification(value=class_value, name=class_name) + MLMClassification( + value=class_value, + name=class_name, + ) for class_name, class_value in class_map.items() ] model_output = ModelOutput( @@ -119,8 +125,8 @@ def eurosat_resnet() -> ItemMLModelExtension: roles=[ "mlm:model", "mlm:weights", - "data" - ] + "data", + ], ), "source_code": pystac.Asset( title="Model implementation.", @@ -129,9 +135,9 @@ def eurosat_resnet() -> ItemMLModelExtension: media_type="text/x-python", roles=[ "mlm:model", - "code" - ] - ) + "code", + ], + ), } ml_model_size = 43000000 @@ -163,7 +169,7 @@ def eurosat_resnet() -> ItemMLModelExtension: -7.882190080512502, 37.13739173208318, 27.911651652899923, - 58.21798141355221 + 58.21798141355221, ] geometry = shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__ item_name = "item_basic" @@ -177,9 +183,7 @@ def eurosat_resnet() -> ItemMLModelExtension: properties={ "start_datetime": start_datetime, "end_datetime": end_datetime, - "description": ( - "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO" - ), + "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO", }, assets=assets, ) @@ -202,7 +206,7 @@ def eurosat_resnet() -> ItemMLModelExtension: extent=pystac.Extent( temporal=pystac.TemporalExtent([[parse_dt(start_datetime), parse_dt(end_datetime)]]), spatial=pystac.SpatialExtent([bbox]), - ) + ), ) col.set_self_href("./examples/collection.json") col.add_item(item) @@ -210,7 +214,7 @@ def eurosat_resnet() -> ItemMLModelExtension: model_asset = cast( FileExtension[pystac.Asset], - pystac.extensions.file.FileExtension.ext(assets["model"], add_if_missing=True) + pystac.extensions.file.FileExtension.ext(assets["model"], add_if_missing=True), ) model_asset.apply(size=ml_model_size) diff --git a/stac_model/input.py b/stac_model/input.py index 19c6e13..22788d7 100644 --- a/stac_model/input.py +++ b/stac_model/input.py @@ -1,6 +1,7 @@ -from typing import Annotated, List, Literal, Optional, TypeAlias, Union +from typing import Annotated, Any, List, Literal, Optional, Sequence, TypeAlias, Union +from typing_extensions import Self -from pydantic import Field +from pydantic import Field, model_validator from stac_model.base import DataType, MLMBaseModel, Number, OmitIfNone, ProcessingExpression @@ -10,6 +11,12 @@ class InputStructure(MLMBaseModel): dim_order: List[str] = Field(min_items=1) data_type: DataType + @model_validator(mode="after") + def validate_dimensions(self) -> Self: + if len(self.shape) != len(self.dim_order): + raise ValueError("Dimension order and shape must be of equal length for corresponding indices.") + return self + class MLMStatistic(MLMBaseModel): # FIXME: add 'Statistics' dep from raster extension (cases required to be triggered) minimum: Annotated[Optional[Number], OmitIfNone] = None @@ -31,7 +38,7 @@ class MLMStatistic(MLMBaseModel): # FIXME: add 'Statistics' dep from raster ext "hamming2", "type-mask", "relative", - "inf" + "inf", ] ] @@ -51,9 +58,54 @@ class MLMStatistic(MLMBaseModel): # FIXME: add 'Statistics' dep from raster ext ] +class ModelBand(MLMBaseModel): + name: str = Field( + description=( + "Name of the band to use for the input, " + "referring to the name of an entry in a 'bands' definition from another STAC extension." + ) + ) + # similar to 'ProcessingExpression', but they can be omitted here + format: Annotated[Optional[str], OmitIfNone] = Field( + default=None, + description="", + ) + expression: Annotated[Optional[Any], OmitIfNone] = Field( + default=None, + description="", + ) + + @model_validator(mode="after") + def validate_expression(self) -> Self: + if ( # mutually dependant + (self.format is not None or self.expression is not None) + and (self.format is None or self.expression is None) + ): + raise ValueError("Model band 'format' and 'expression' are mutually dependant.") + return self + + class ModelInput(MLMBaseModel): name: str - bands: List[str] # order is critical here (same index as dim shape), allow duplicate if the model needs it somehow + # order is critical here (same index as dim shape), allow duplicate if the model needs it somehow + bands: Sequence[Union[str, ModelBand]] = Field( + description=( + "List of bands that compose the input. " + "If a string is used, it is implied to correspond to a named-band. " + "If no band is needed for the input, use an empty array." + ), + examples=[ + [ + "B01", + {"name": "B02"}, + { + "name": "NDVI", + "format": "rio-calc", + "expression": "(B08 - B04) / (B08 + B04)", + }, + ], + ], + ) input: InputStructure norm_by_channel: Annotated[Optional[bool], OmitIfNone] = None norm_type: Annotated[Optional[NormalizeType], OmitIfNone] = None diff --git a/stac_model/schema.py b/stac_model/schema.py index a9c8146..db7ae37 100644 --- a/stac_model/schema.py +++ b/stac_model/schema.py @@ -97,15 +97,18 @@ def get_schema_uri(cls) -> str: @overload @classmethod - def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> "AssetMLModelExtension": ... + def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> "AssetMLModelExtension": + ... @overload @classmethod - def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> "ItemMLModelExtension": ... + def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> "ItemMLModelExtension": + ... @overload @classmethod - def ext(cls, obj: pystac.Collection, add_if_missing: bool = False) -> "CollectionMLModelExtension": ... + def ext(cls, obj: pystac.Collection, add_if_missing: bool = False) -> "CollectionMLModelExtension": + ... # @overload # @classmethod diff --git a/tests/conftest.py b/tests/conftest.py index 75c81b5..417a913 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,8 @@ def get_all_stac_item_examples() -> List[str]: all_json = glob.glob("**/*.json", root_dir=EXAMPLES_DIR, recursive=True) all_geojson = glob.glob("**/*.geojson", root_dir=EXAMPLES_DIR, recursive=True) all_stac_items = [ - path for path in all_json + all_geojson + path + for path in all_json + all_geojson if os.path.splitext(os.path.basename(path))[0] not in ["collection", "catalog"] ] return all_stac_items diff --git a/tests/test_stac_model.py b/tests/test_stac_model.py new file mode 100644 index 0000000..83d169c --- /dev/null +++ b/tests/test_stac_model.py @@ -0,0 +1,52 @@ +import pydantic +import pytest + +from stac_model.input import InputStructure, ModelBand, ModelInput + + +@pytest.mark.parametrize( + "bands", + [ + ["B04", "B03", "B02"], + [{"name": "B04"}, {"name": "B03"}, {"name": "B02"}], + [{"name": "NDVI", "format": "rio-calc", "expression": "(B08 - B04) / (B08 + B04)"}], + [ + "B04", + {"name": "B03"}, + "B02", + {"name": "NDVI", "format": "rio-calc", "expression": "(B08 - B04) / (B08 + B04)"}, + ], + ], +) +def test_model_band(bands): + mlm_input = ModelInput( + name="test", + bands=bands, + input=InputStructure( + shape=[-1, len(bands), 64, 64], + dim_order=["batch", "channel", "height", "width"], + data_type="float32", + ), + ) + mlm_bands = mlm_input.dict()["bands"] + assert mlm_bands == bands + + +@pytest.mark.parametrize( + "bands", + [ + [{"name": "test", "expression": "missing-format"}], + [{"name": "test", "format": "missing-expression"}], + ], +) +def test_model_band_format_expression_dependency(bands: list[ModelBand]) -> None: + with pytest.raises(pydantic.ValidationError): + ModelInput( + name="test", + bands=bands, + input=InputStructure( + shape=[-1, len(bands), 64, 64], + dim_order=["batch", "channel", "height", "width"], + data_type="float32", + ), + )