diff --git a/.gitignore b/.gitignore
index a9b169b..4d99463 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,3 @@
-Untitled.ipynb
-/package-lock.json
-/node_modules
-.vscode
-.idea
-
### ArchLinuxPackages ###
*.tar
*.tar.*
@@ -156,6 +150,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
+**/.benchmarks
cover/
# Translations
@@ -179,6 +174,7 @@ target/
# Jupyter Notebook
.ipynb_checkpoints
+Untitled.ipynb
# IPython
profile_default/
@@ -261,7 +257,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
### Git ###
# Created by git for backups. To disable backups in Git:
@@ -578,12 +574,7 @@ tags
[._]*.un~
### VisualStudioCode ###
-.vscode/*
-!.vscode/settings.json
-!.vscode/tasks.json
-!.vscode/launch.json
-!.vscode/extensions.json
-!.vscode/*.code-snippets
+.vscode/
# Local History for Visual Studio Code
.history/
@@ -936,6 +927,7 @@ FakesAssemblies/
# Node.js Tools for Visual Studio
.ntvs_analysis.dat
node_modules/
+/package-lock.json
# Visual Studio 6 build log
*.plg
@@ -1036,5 +1028,3 @@ FodyWeavers.xsd
### VisualStudio Patch ###
# Additional files built by Visual Studio
-
-# End of https://www.toptal.com/developers/gitignore/api/linux,archlinuxpackages,osx,windows,python,c,django,database,pycharm,visualstudio,visualstudiocode,vim,zsh,git,diff,microsoftoffice,spreadsheet,ssh,certificates
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 615be94..d4d3eca 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Add the missing JSON schema `item_assets` definition under a Collection to ensure compatibility with
the [Item Assets](https://github.com/stac-extensions/item-assets) extension, as mentioned this specification.
+- Add `ModelBand` representation using `name`, `format` and `expression` properties to allow derived band references
+ (fixes [crim-ca/mlm-extension#7](https://github.com/crim-ca/mlm-extension/discussions/7)).
### Changed
- Adds a job to publish.yaml to publish the stac-model package
diff --git a/Makefile b/Makefile
index 3093616..6c0fa19 100644
--- a/Makefile
+++ b/Makefile
@@ -81,6 +81,10 @@ lint:
.PHONY: check-lint
check-lint: lint
+.PHONY: format-lint
+format-lint:
+ poetry run ruff --config=pyproject.toml --fix ./
+
.PHONY: install-npm
install-npm:
npm install
@@ -101,7 +105,8 @@ check-examples: install-npm
format-examples: install-npm
npm run format-examples
-fix-%: format-%s
+FORMATTERS := lint markdown examples
+$(addprefix fix-, $(FORMATTERS)): fix-%: format-%
.PHONY: lint-all
lint-all: lint mypy check-safety check-markdown
diff --git a/README.md b/README.md
index 6ee4fc8..a80a2a0 100644
--- a/README.md
+++ b/README.md
@@ -209,18 +209,18 @@ set to `true`, there would be no `accelerator` to contain against. To avoid conf
### Model Input Object
-| Field Name | Type | Description |
-|-------------------------|---------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. |
-| bands | \[string] | **REQUIRED** The names of the raster bands used to train or fine-tune the model, which may be all or a subset of bands available in a STAC Item's [Band Object](#bands-and-statistics). If no band applies for one input, use an empty array. |
-| input | [Input Structure Object](#input-structure-object) | **REQUIRED** The N-dimensional array definition that describes the shape, dimension ordering, and data type. |
-| description | string | Additional details about the input such as describing its purpose or expected source that cannot be represented by other properties. |
-| norm_by_channel | boolean | Whether to normalize each channel by channel-wise statistics or to normalize by dataset statistics. If True, use an array of `statistics` of same dimensionality and order as the `bands` field in this object. |
-| norm_type | [Normalize Enum](#normalize-enum) \| null | Normalization method. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. |
-| norm_clip | \[number] | When `norm_type = "clip"`, this array supplies the value for each `bands` item, which is used to divide each band before clipping values between 0 and 1. |
-| resize_type | [Resize Enum](#resize-enum) \| null | High-level descriptor of the rescaling method to change image shape. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. |
-| statistics | \[[Statistics Object](#bands-and-statistics)] | Dataset statistics for the training dataset used to normalize the inputs. |
-| pre_processing_function | [Processing Expression](#processing-expression) \| null | Custom preprocessing function where normalization and rescaling, and any other significant operations takes place. |
+| Field Name | Type | Description |
+|-------------------------|---------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| name | string | **REQUIRED** Name of the input variable defined by the model. If no explicit name is defined by the model, an informative name (e.g.: `"RGB Time Series"`) can be used instead. |
+| bands | \[string \| [Model Band Object](#model-band-object)] | **REQUIRED** The raster band references used to train, fine-tune or perform inference with the model, which may be all or a subset of bands available in a STAC Item's [Band Object](#bands-and-statistics). If no band applies for one input, use an empty array. |
+| input | [Input Structure Object](#input-structure-object) | **REQUIRED** The N-dimensional array definition that describes the shape, dimension ordering, and data type. |
+| description | string | Additional details about the input such as describing its purpose or expected source that cannot be represented by other properties. |
+| norm_by_channel | boolean | Whether to normalize each channel by channel-wise statistics or to normalize by dataset statistics. If True, use an array of `statistics` of same dimensionality and order as the `bands` field in this object. |
+| norm_type | [Normalize Enum](#normalize-enum) \| null | Normalization method. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. |
+| norm_clip | \[number] | When `norm_type = "clip"`, this array supplies the value for each `bands` item, which is used to divide each band before clipping values between 0 and 1. |
+| resize_type | [Resize Enum](#resize-enum) \| null | High-level descriptor of the rescaling method to change image shape. Select an appropriate option or `null` when none applies. Consider using `pre_processing_function` for custom implementations or more complex combinations. |
+| statistics | \[[Statistics Object](#bands-and-statistics)] | Dataset statistics for the training dataset used to normalize the inputs. |
+| pre_processing_function | [Processing Expression](#processing-expression) \| null | Custom preprocessing function where normalization and rescaling, and any other significant operations takes place. The `pre_processing_function` should be applied over all available `bands`. For respective band operations, see [Model Band Object](#model-band-object). |
Fields that accept the `null` value can be considered `null` when omitted entirely for parsing purposes.
However, setting `null` explicitly when this information is known by the model provider can help users understand
@@ -253,6 +253,9 @@ and [Common Band Names][stac-band-names].
Only bands used as input to the model should be included in the MLM `bands` field.
To avoid duplicating the information, MLM only uses the `name` of whichever "Band Object" is defined in the STAC Item.
+An input's `bands` definition can either be a plain `string` or a [Model Band Object](#model-band-object).
+When a `string` is employed directly, the value should be implicitly mapped to the `name` property of the
+explicit object representation.
One distinction from the [STAC 1.1 - Band Object][stac-1.1-band] in MLM is that [Statistics][stac-1.1-stats] object
(or the corresponding [STAC Raster - Statistics][stac-raster-stats] for STAC 1.0) are not
@@ -269,6 +272,29 @@ properties of the model.
[stac-raster-stats]: https://github.com/stac-extensions/raster?tab=readme-ov-file#statistics-object
[stac-band-names]: https://github.com/stac-extensions/eo?tab=readme-ov-file#common-band-names
+#### Model Band Object
+
+| Field Name | Type | Description |
+|------------|--------|----------------------------------------------------------------------------------------------------------------------------------------|
+| name | string | **REQUIRED** Name of the band referring to an extended band definition (see [Bands](#bands-and-statistics). |
+| format | string | The type of expression that is specified in the `expression` property. |
+| expression | \* | An expression compliant with the `format` specified. The expression can be applied to any data type and depends on the `format` given. |
+
+> :information_source:
+> Although `format` and `expression` are not required in this context, they are mutually dependent on each other.
+> See also [Processing Expression](#processing-expression) for more details and examples.
+
+The `format` and `expression` properties can serve multiple purpose.
+
+1. Applying a band-specific pre-processing step,
+ in contrast to [`pre_processing_function`](#model-input-object) applied over all bands.
+ For example, reshaping a band to align its dimensions with other bands before stacking them.
+
+2. Defining a derived-band operation or a calculation that produces a virtual band from other band references.
+ For example, computing an indice that applies an arithmetic combination of other bands.
+
+For a concrete example, see [examples/item_bands_expression.json](examples/item_bands_expression.json).
+
#### Data Type Enum
When describing the `data_type` provided by a [Band](#bands-and-statistics), whether for defining
diff --git a/examples/collection.json b/examples/collection.json
index 46c78ff..aff632b 100644
--- a/examples/collection.json
+++ b/examples/collection.json
@@ -52,6 +52,10 @@
"href": "item_basic.json",
"rel": "item"
},
+ {
+ "href": "item_bands_expression.json",
+ "rel": "item"
+ },
{
"href": "item_eo_bands.json",
"rel": "item"
diff --git a/examples/item_bands_expression.json b/examples/item_bands_expression.json
new file mode 100644
index 0000000..3fdd4aa
--- /dev/null
+++ b/examples/item_bands_expression.json
@@ -0,0 +1,204 @@
+{
+ "$comment": "Demonstrate the use of MLM and EO for bands description, with EO bands directly in the Model Asset.",
+ "stac_version": "1.0.0",
+ "stac_extensions": [
+ "https://crim-ca.github.io/mlm-extension/v1.1.0/schema.json",
+ "https://stac-extensions.github.io/eo/v1.1.0/schema.json",
+ "https://stac-extensions.github.io/raster/v1.1.0/schema.json",
+ "https://stac-extensions.github.io/file/v1.0.0/schema.json",
+ "https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json"
+ ],
+ "type": "Feature",
+ "id": "resnet-18_sentinel-2_all_moco_classification",
+ "collection": "ml-model-examples",
+ "geometry": {
+ "type": "Polygon",
+ "coordinates": [
+ [
+ [
+ -7.882190080512502,
+ 37.13739173208318
+ ],
+ [
+ -7.882190080512502,
+ 58.21798141355221
+ ],
+ [
+ 27.911651652899923,
+ 58.21798141355221
+ ],
+ [
+ 27.911651652899923,
+ 37.13739173208318
+ ],
+ [
+ -7.882190080512502,
+ 37.13739173208318
+ ]
+ ]
+ ]
+ },
+ "bbox": [
+ -7.882190080512502,
+ 37.13739173208318,
+ 27.911651652899923,
+ 58.21798141355221
+ ],
+ "properties": {
+ "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO",
+ "datetime": null,
+ "start_datetime": "1900-01-01T00:00:00Z",
+ "end_datetime": "9999-12-31T23:59:59Z",
+ "mlm:name": "Resnet-18 Sentinel-2 ALL MOCO",
+ "mlm:tasks": [
+ "classification"
+ ],
+ "mlm:architecture": "ResNet",
+ "mlm:framework": "pytorch",
+ "mlm:framework_version": "2.1.2+cu121",
+ "file:size": 43000000,
+ "mlm:memory_size": 1,
+ "mlm:total_parameters": 11700000,
+ "mlm:pretrained_source": "EuroSat Sentinel-2",
+ "mlm:accelerator": "cuda",
+ "mlm:accelerator_constrained": false,
+ "mlm:accelerator_summary": "Unknown",
+ "mlm:batch_size_suggestion": 256,
+ "mlm:input": [
+ {
+ "name": "RBG+NDVI Bands Sentinel-2 Batch",
+ "bands": [
+ {
+ "name": "B04"
+ },
+ {
+ "name": "B03"
+ },
+ {
+ "name": "B02"
+ },
+ {
+ "name": "NDVI",
+ "format": "rio-calc",
+ "expression": "(B08 - B04) / (B08 + B04)"
+ }
+ ],
+ "input": {
+ "shape": [
+ -1,
+ 13,
+ 64,
+ 64
+ ],
+ "dim_order": [
+ "batch",
+ "channel",
+ "height",
+ "width"
+ ],
+ "data_type": "float32"
+ }
+ }
+ ],
+ "mlm:output": [
+ {
+ "name": "classification",
+ "tasks": [
+ "segmentation",
+ "semantic-segmentation"
+ ],
+ "result": {
+ "shape": [
+ -1,
+ 10
+ ],
+ "dim_order": [
+ "batch",
+ "class"
+ ],
+ "data_type": "float32"
+ },
+ "classification_classes": [
+ {
+ "value": 1,
+ "name": "vegetation",
+ "title": "Vegetation",
+ "description": "Pixels were vegetation is detected.",
+ "color_hint": "00FF00",
+ "nodata": false
+ },
+ {
+ "value": 0,
+ "name": "background",
+ "title": "Non-Vegetation",
+ "description": "Anything that is not classified as vegetation.",
+ "color_hint": "000000",
+ "nodata": false
+ }
+ ],
+ "post_processing_function": null
+ }
+ ]
+ },
+ "assets": {
+ "weights": {
+ "href": "https://example.com/model-rgb-ndvi.pth",
+ "title": "Pytorch weights checkpoint",
+ "description": "A vegetation classification model trained on Sentinel-2 imagery and NDVI.",
+ "type": "application/octet-stream; application=pytorch",
+ "roles": [
+ "mlm:model",
+ "mlm:weights"
+ ],
+ "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.",
+ "eo:bands": [
+ {
+ "name": "B02",
+ "common_name": "blue",
+ "description": "Blue (band 2)",
+ "center_wavelength": 0.49,
+ "full_width_half_max": 0.098
+ },
+ {
+ "name": "B03",
+ "common_name": "green",
+ "description": "Green (band 3)",
+ "center_wavelength": 0.56,
+ "full_width_half_max": 0.045
+ },
+ {
+ "name": "B04",
+ "common_name": "red",
+ "description": "Red (band 4)",
+ "center_wavelength": 0.665,
+ "full_width_half_max": 0.038
+ },
+ {
+ "name": "B08",
+ "common_name": "nir",
+ "description": "NIR 1 (band 8)",
+ "center_wavelength": 0.842,
+ "full_width_half_max": 0.145
+ }
+ ]
+ }
+ },
+ "links": [
+ {
+ "rel": "collection",
+ "href": "./collection.json",
+ "type": "application/json"
+ },
+ {
+ "rel": "self",
+ "href": "./item_bands_expression.json",
+ "type": "application/geo+json"
+ },
+ {
+ "rel": "derived_from",
+ "href": "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a",
+ "type": "application/json",
+ "ml-aoi:split": "train"
+ }
+ ]
+}
diff --git a/json-schema/schema.json b/json-schema/schema.json
index 21cd5f2..b31c73f 100644
--- a/json-schema/schema.json
+++ b/json-schema/schema.json
@@ -711,13 +711,45 @@
}
},
"ModelBands": {
+ "description": "List of bands (if any) that compose the input. Band order represents the index position of the bands.",
"allOf": [
{
"$comment": "No 'minItems' here to support model inputs not using any band (other data source).",
"type": "array",
"items": {
- "type": "string",
- "minLength": 1
+ "oneOf": [
+ {
+ "description": "Implied named-band with the name directly provided.",
+ "type": "string",
+ "minLength": 1
+ },
+ {
+ "description": "Explicit named-band with optional derived expression to obtain it.",
+ "type": "object",
+ "required": [
+ "name"
+ ],
+ "properties": {
+ "name": {
+ "type": "string",
+ "minLength": 1
+ },
+ "format": {
+ "description": "Format to interpret the specified expression used to obtain the band.",
+ "type": "string",
+ "minLength": 1
+ },
+ "expression": {
+ "description": "Any representation relevant for the specified 'format'."
+ }
+ },
+ "dependencies": {
+ "format": ["expression"],
+ "expression": ["format"]
+ },
+ "additionalProperties": false
+ }
+ ]
}
},
{
diff --git a/pyproject.toml b/pyproject.toml
index 8e52e60..c5ad032 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -224,6 +224,7 @@ select = [
[tool.ruff.lint.isort]
known-local-folder = ["tests", "conftest"]
known-first-party = ["stac_model"]
+extra-standard-library = ["typing_extensions"]
[tool.mypy]
# https://github.com/python/mypy
diff --git a/stac_model/base.py b/stac_model/base.py
index 4e8cc6b..c5ff7c6 100644
--- a/stac_model/base.py
+++ b/stac_model/base.py
@@ -116,7 +116,7 @@ class TaskEnum(str, Enum):
ModelTask = Union[ModelTaskNames, TaskEnum]
-class ProcessingExpression(BaseModel):
+class ProcessingExpression(MLMBaseModel):
# FIXME: should use 'pystac' reference, but 'processing' extension is not implemented yet!
format: str
expression: Any
diff --git a/stac_model/examples.py b/stac_model/examples.py
index c781a5b..47be1db 100644
--- a/stac_model/examples.py
+++ b/stac_model/examples.py
@@ -63,7 +63,10 @@ def eurosat_resnet() -> ItemMLModelExtension:
1231.58581042,
]
stats = [
- MLMStatistic(mean=mean, stddev=stddev)
+ MLMStatistic(
+ mean=mean,
+ stddev=stddev,
+ )
for mean, stddev in zip(stats_mean, stats_stddev)
]
model_input = ModelInput(
@@ -82,7 +85,7 @@ def eurosat_resnet() -> ItemMLModelExtension:
result_struct = ModelResult(
shape=[-1, 10],
dim_order=["batch", "class"],
- data_type="float32"
+ data_type="float32",
)
class_map = {
"Annual Crop": 0,
@@ -97,7 +100,10 @@ def eurosat_resnet() -> ItemMLModelExtension:
"SeaLake": 9,
}
class_objects = [
- MLMClassification(value=class_value, name=class_name)
+ MLMClassification(
+ value=class_value,
+ name=class_name,
+ )
for class_name, class_value in class_map.items()
]
model_output = ModelOutput(
@@ -119,8 +125,8 @@ def eurosat_resnet() -> ItemMLModelExtension:
roles=[
"mlm:model",
"mlm:weights",
- "data"
- ]
+ "data",
+ ],
),
"source_code": pystac.Asset(
title="Model implementation.",
@@ -129,9 +135,9 @@ def eurosat_resnet() -> ItemMLModelExtension:
media_type="text/x-python",
roles=[
"mlm:model",
- "code"
- ]
- )
+ "code",
+ ],
+ ),
}
ml_model_size = 43000000
@@ -163,7 +169,7 @@ def eurosat_resnet() -> ItemMLModelExtension:
-7.882190080512502,
37.13739173208318,
27.911651652899923,
- 58.21798141355221
+ 58.21798141355221,
]
geometry = shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__
item_name = "item_basic"
@@ -177,9 +183,7 @@ def eurosat_resnet() -> ItemMLModelExtension:
properties={
"start_datetime": start_datetime,
"end_datetime": end_datetime,
- "description": (
- "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO"
- ),
+ "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO",
},
assets=assets,
)
@@ -202,7 +206,7 @@ def eurosat_resnet() -> ItemMLModelExtension:
extent=pystac.Extent(
temporal=pystac.TemporalExtent([[parse_dt(start_datetime), parse_dt(end_datetime)]]),
spatial=pystac.SpatialExtent([bbox]),
- )
+ ),
)
col.set_self_href("./examples/collection.json")
col.add_item(item)
@@ -210,7 +214,7 @@ def eurosat_resnet() -> ItemMLModelExtension:
model_asset = cast(
FileExtension[pystac.Asset],
- pystac.extensions.file.FileExtension.ext(assets["model"], add_if_missing=True)
+ pystac.extensions.file.FileExtension.ext(assets["model"], add_if_missing=True),
)
model_asset.apply(size=ml_model_size)
diff --git a/stac_model/input.py b/stac_model/input.py
index 19c6e13..22788d7 100644
--- a/stac_model/input.py
+++ b/stac_model/input.py
@@ -1,6 +1,7 @@
-from typing import Annotated, List, Literal, Optional, TypeAlias, Union
+from typing import Annotated, Any, List, Literal, Optional, Sequence, TypeAlias, Union
+from typing_extensions import Self
-from pydantic import Field
+from pydantic import Field, model_validator
from stac_model.base import DataType, MLMBaseModel, Number, OmitIfNone, ProcessingExpression
@@ -10,6 +11,12 @@ class InputStructure(MLMBaseModel):
dim_order: List[str] = Field(min_items=1)
data_type: DataType
+ @model_validator(mode="after")
+ def validate_dimensions(self) -> Self:
+ if len(self.shape) != len(self.dim_order):
+ raise ValueError("Dimension order and shape must be of equal length for corresponding indices.")
+ return self
+
class MLMStatistic(MLMBaseModel): # FIXME: add 'Statistics' dep from raster extension (cases required to be triggered)
minimum: Annotated[Optional[Number], OmitIfNone] = None
@@ -31,7 +38,7 @@ class MLMStatistic(MLMBaseModel): # FIXME: add 'Statistics' dep from raster ext
"hamming2",
"type-mask",
"relative",
- "inf"
+ "inf",
]
]
@@ -51,9 +58,54 @@ class MLMStatistic(MLMBaseModel): # FIXME: add 'Statistics' dep from raster ext
]
+class ModelBand(MLMBaseModel):
+ name: str = Field(
+ description=(
+ "Name of the band to use for the input, "
+ "referring to the name of an entry in a 'bands' definition from another STAC extension."
+ )
+ )
+ # similar to 'ProcessingExpression', but they can be omitted here
+ format: Annotated[Optional[str], OmitIfNone] = Field(
+ default=None,
+ description="",
+ )
+ expression: Annotated[Optional[Any], OmitIfNone] = Field(
+ default=None,
+ description="",
+ )
+
+ @model_validator(mode="after")
+ def validate_expression(self) -> Self:
+ if ( # mutually dependant
+ (self.format is not None or self.expression is not None)
+ and (self.format is None or self.expression is None)
+ ):
+ raise ValueError("Model band 'format' and 'expression' are mutually dependant.")
+ return self
+
+
class ModelInput(MLMBaseModel):
name: str
- bands: List[str] # order is critical here (same index as dim shape), allow duplicate if the model needs it somehow
+ # order is critical here (same index as dim shape), allow duplicate if the model needs it somehow
+ bands: Sequence[Union[str, ModelBand]] = Field(
+ description=(
+ "List of bands that compose the input. "
+ "If a string is used, it is implied to correspond to a named-band. "
+ "If no band is needed for the input, use an empty array."
+ ),
+ examples=[
+ [
+ "B01",
+ {"name": "B02"},
+ {
+ "name": "NDVI",
+ "format": "rio-calc",
+ "expression": "(B08 - B04) / (B08 + B04)",
+ },
+ ],
+ ],
+ )
input: InputStructure
norm_by_channel: Annotated[Optional[bool], OmitIfNone] = None
norm_type: Annotated[Optional[NormalizeType], OmitIfNone] = None
diff --git a/stac_model/schema.py b/stac_model/schema.py
index a9c8146..db7ae37 100644
--- a/stac_model/schema.py
+++ b/stac_model/schema.py
@@ -97,15 +97,18 @@ def get_schema_uri(cls) -> str:
@overload
@classmethod
- def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> "AssetMLModelExtension": ...
+ def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> "AssetMLModelExtension":
+ ...
@overload
@classmethod
- def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> "ItemMLModelExtension": ...
+ def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> "ItemMLModelExtension":
+ ...
@overload
@classmethod
- def ext(cls, obj: pystac.Collection, add_if_missing: bool = False) -> "CollectionMLModelExtension": ...
+ def ext(cls, obj: pystac.Collection, add_if_missing: bool = False) -> "CollectionMLModelExtension":
+ ...
# @overload
# @classmethod
diff --git a/tests/conftest.py b/tests/conftest.py
index 75c81b5..417a913 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,7 +22,8 @@ def get_all_stac_item_examples() -> List[str]:
all_json = glob.glob("**/*.json", root_dir=EXAMPLES_DIR, recursive=True)
all_geojson = glob.glob("**/*.geojson", root_dir=EXAMPLES_DIR, recursive=True)
all_stac_items = [
- path for path in all_json + all_geojson
+ path
+ for path in all_json + all_geojson
if os.path.splitext(os.path.basename(path))[0] not in ["collection", "catalog"]
]
return all_stac_items
diff --git a/tests/test_stac_model.py b/tests/test_stac_model.py
new file mode 100644
index 0000000..83d169c
--- /dev/null
+++ b/tests/test_stac_model.py
@@ -0,0 +1,52 @@
+import pydantic
+import pytest
+
+from stac_model.input import InputStructure, ModelBand, ModelInput
+
+
+@pytest.mark.parametrize(
+ "bands",
+ [
+ ["B04", "B03", "B02"],
+ [{"name": "B04"}, {"name": "B03"}, {"name": "B02"}],
+ [{"name": "NDVI", "format": "rio-calc", "expression": "(B08 - B04) / (B08 + B04)"}],
+ [
+ "B04",
+ {"name": "B03"},
+ "B02",
+ {"name": "NDVI", "format": "rio-calc", "expression": "(B08 - B04) / (B08 + B04)"},
+ ],
+ ],
+)
+def test_model_band(bands):
+ mlm_input = ModelInput(
+ name="test",
+ bands=bands,
+ input=InputStructure(
+ shape=[-1, len(bands), 64, 64],
+ dim_order=["batch", "channel", "height", "width"],
+ data_type="float32",
+ ),
+ )
+ mlm_bands = mlm_input.dict()["bands"]
+ assert mlm_bands == bands
+
+
+@pytest.mark.parametrize(
+ "bands",
+ [
+ [{"name": "test", "expression": "missing-format"}],
+ [{"name": "test", "format": "missing-expression"}],
+ ],
+)
+def test_model_band_format_expression_dependency(bands: list[ModelBand]) -> None:
+ with pytest.raises(pydantic.ValidationError):
+ ModelInput(
+ name="test",
+ bands=bands,
+ input=InputStructure(
+ shape=[-1, len(bands), 64, 64],
+ dim_order=["batch", "channel", "height", "width"],
+ data_type="float32",
+ ),
+ )