From 2d6c70b63738e62908eef432ad046f186835ed58 Mon Sep 17 00:00:00 2001 From: Francis Charette-Migneault Date: Thu, 4 Apr 2024 20:52:00 -0400 Subject: [PATCH] update examples working against JSON schema (except check for cross-bands [AnyBandsRef]) --- README.md | 2 +- examples/item_basic.json | 116 +++++++++ ...ample_eo_bands.json => item_eo_bands.json} | 3 +- examples/item_multi_io.json | 242 ++++++++++++++++++ .../{example.json => item_raster_bands.json} | 68 +---- json-schema/schema.json | 57 ++++- stac_model/input.py | 2 +- stac_model/runtime.py | 10 +- tests/test_schema.py | 6 +- 9 files changed, 428 insertions(+), 78 deletions(-) create mode 100644 examples/item_basic.json rename examples/{example_eo_bands.json => item_eo_bands.json} (99%) create mode 100644 examples/item_multi_io.json rename examples/{example.json => item_raster_bands.json} (86%) diff --git a/README.md b/README.md index cf595df..78e4553 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ The fields in the table below can be used in these parts of STAC documents: | mlm:name | string | **REQUIRED** A unique name for the model. This can include, but must be distinct, from simply naming the model architecture. If there is a publication or other published work related to the model, use the official name of the model. | | mlm:architecture | [Model Architecture](#model-architecture) string | **REQUIRED** A generic and well established architecture name of the model. | | mlm:tasks | \[[Task Enum](#task-enum)] | **REQUIRED** Specifies the Machine Learning tasks for which the model can be used for. If multi-tasks outputs are provided by distinct model heads, specify all available tasks under the main properties and specify respective tasks in each [Model Output Object](#model-output-object). | -| mlm:framework | string | **REQUIRED** Framework used to train the model (ex: PyTorch, TensorFlow). | +| mlm:framework | string | Framework used to train the model (ex: PyTorch, TensorFlow). | | mlm:framework_version | string | The `framework` library version. Some models require a specific version of the machine learning `framework` to run. | | mlm:memory_size | integer | The in-memory size of the model on the accelerator during inference (bytes). | | mlm:total_parameters | integer | Total number of model parameters, including trainable and non-trainable parameters. | diff --git a/examples/item_basic.json b/examples/item_basic.json new file mode 100644 index 0000000..0778163 --- /dev/null +++ b/examples/item_basic.json @@ -0,0 +1,116 @@ +{ + "stac_version": "1.0.0", + "stac_extensions": [ + "https://stac-extensions.github.io/mlm/v1.0.0/schema.json" + ], + "type": "Feature", + "id": "example-model", + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + -7.882190080512502, + 37.13739173208318 + ], + [ + -7.882190080512502, + 58.21798141355221 + ], + [ + 27.911651652899925, + 58.21798141355221 + ], + [ + 27.911651652899925, + 37.13739173208318 + ], + [ + -7.882190080512502, + 37.13739173208318 + ] + ] + ] + }, + "bbox": [ + -7.882190080512502, + 37.13739173208318, + 27.911651652899923, + 58.21798141355221 + ], + "properties": { + "datetime": null, + "start_datetime": "1900-01-01T00:00:00Z", + "end_datetime": "9999-12-31T23:59:59Z", + "mlm:name": "example-model", + "mlm:tasks": [ + "classification" + ], + "mlm:architecture": "ResNet", + "mlm:input": [ + { + "name": "Model with RGB input that does not refer to any band.", + "bands": [], + "input": { + "shape": [ + -1, + 3, + 64, + 64 + ], + "dim_order": [ + "batch", + "channel", + "height", + "width" + ], + "data_type": "float32" + } + } + ], + "mlm:output": [ + { + "name": "classification", + "tasks": [ + "classification" + ], + "result": { + "shape": [ + -1, + 1 + ], + "dim_order": [ + "batch", + "class" + ], + "data_type": "uint8" + }, + "classification_classes": [ + { + "value": 0, + "name": "BACKGROUND", + "description": "Background non-city.", + "color_hint": [0, 0, 0] + }, + { + "value": 1, + "name": "CITY", + "description": "A city is detected.", + "color_hint": [0, 0, 255] + } + ] + } + ] + }, + "assets": { + "model": { + "href": "https://huggingface.co/example/model-card", + "title": "Pytorch weights checkpoint", + "description": "Example model.", + "type": "text/html", + "roles": [ + "mlm:model" + ] + } + } +} diff --git a/examples/example_eo_bands.json b/examples/item_eo_bands.json similarity index 99% rename from examples/example_eo_bands.json rename to examples/item_eo_bands.json index adb29d2..60a5868 100644 --- a/examples/example_eo_bands.json +++ b/examples/item_eo_bands.json @@ -51,6 +51,7 @@ "mlm:tasks": [ "classification" ], + "mlm:architecture": "ResNet", "mlm:framework": "pytorch", "mlm:framework_version": "2.1.2+cu121", "file:size": 43000000, @@ -60,7 +61,7 @@ "mlm:accelerator": "cuda", "mlm:accelerator_constrained": false, "mlm:accelerator_summary": "Unknown", - "mlm:batch_size_suggestion": null, + "mlm:batch_size_suggestion": 256, "mlm:input": [ { "name": "13 Band Sentinel-2 Batch", diff --git a/examples/item_multi_io.json b/examples/item_multi_io.json new file mode 100644 index 0000000..cd1b465 --- /dev/null +++ b/examples/item_multi_io.json @@ -0,0 +1,242 @@ +{ + "stac_version": "1.0.0", + "stac_extensions": [ + "https://stac-extensions.github.io/mlm/v1.0.0/schema.json", + "https://stac-extensions.github.io/raster/v1.1.0/schema.json", + "https://stac-extensions.github.io/file/v1.0.0/schema.json", + "https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json" + ], + "type": "Feature", + "id": "resnet-18_sentinel-2_all_moco_classification", + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + -7.882190080512502, + 37.13739173208318 + ], + [ + -7.882190080512502, + 58.21798141355221 + ], + [ + 27.911651652899925, + 58.21798141355221 + ], + [ + 27.911651652899925, + 37.13739173208318 + ], + [ + -7.882190080512502, + 37.13739173208318 + ] + ] + ] + }, + "bbox": [ + -7.882190080512502, + 37.13739173208318, + 27.911651652899923, + 58.21798141355221 + ], + "properties": { + "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO", + "datetime": null, + "start_datetime": "1900-01-01T00:00:00Z", + "end_datetime": "9999-12-31T23:59:59Z", + "mlm:name": "Resnet-18 Sentinel-2 ALL MOCO", + "mlm:tasks": [ + "classification" + ], + "mlm:architecture": "ResNet", + "mlm:framework": "pytorch", + "mlm:framework_version": "2.1.2+cu121", + "file:size": 43000000, + "mlm:memory_size": 1, + "mlm:total_parameters": 11700000, + "mlm:pretrained_source": "EuroSat Sentinel-2", + "mlm:accelerator": "cuda", + "mlm:accelerator_constrained": false, + "mlm:accelerator_summary": "Unknown", + "mlm:batch_size_suggestion": 256, + "mlm:input": [ + { + "name": "RGB", + "bands": [ + "B04", + "B03", + "B02" + ], + "input": { + "shape": [ + -1, + 3, + 64, + 64 + ], + "dim_order": [ + "batch", + "channel", + "height", + "width" + ], + "data_type": "uint16" + }, + "norm_by_channel": false, + "norm_type": null, + "resize_type": null + }, + { + "name": "NDVI", + "bands": [ + "B04", + "B08" + ], + "pre_processing_function": { + "format": "gdal-calc", + "expression": "(A - B) / (A + B)" + }, + "input": { + "shape": [ + -1, + 1, + 64, + 64 + ], + "dim_order": [ + "batch", + "ndvi", + "height", + "width" + ], + "data_type": "uint16" + } + } + ], + "mlm:output": [ + { + "name": "vegetation-segmentation", + "tasks": [ + "semantic-segmentation" + ], + "result": { + "shape": [ + -1, + 1 + ], + "dim_order": [ + "batch", + "class" + ], + "data_type": "uint8" + }, + "classification_classes": [ + { + "value": 0, + "name": "NON_VEGETATION", + "description": "background pixels", + "color_hint": null + }, + { + "value": 1, + "name": "VEGETATION", + "description": "pixels where vegetation was detected", + "color_hint": [0, 255, 0] + } + ], + "post_processing_function": null + }, + { + "name": "inverse-mask", + "tasks": [ + "semantic-segmentation" + ], + "result": { + "shape": [ + -1, + 1 + ], + "dim_order": [ + "batch", + "class" + ], + "data_type": "uint8" + }, + "classification_classes": [ + { + "value": 0, + "name": "NON_VEGETATION", + "description": "background pixels", + "color_hint": [255, 255, 255] + }, + { + "value": 1, + "name": "VEGETATION", + "description": "pixels where vegetation was detected", + "color_hint": [0, 0, 0] + } + ], + "post_processing_function": { + "format": "gdal-calc", + "expression": "logical_not(A)" + } + } + ], + "raster:bands": [ + { + "name": "B02 - blue", + "nodata": 0, + "data_type": "uint16", + "bits_per_sample": 15, + "spatial_resolution": 10, + "scale": 0.0001, + "offset": 0, + "unit": "m" + }, + { + "name": "B03 - green", + "nodata": 0, + "data_type": "uint16", + "bits_per_sample": 15, + "spatial_resolution": 10, + "scale": 0.0001, + "offset": 0, + "unit": "m" + }, + { + "name": "B04 - red", + "nodata": 0, + "data_type": "uint16", + "bits_per_sample": 15, + "spatial_resolution": 10, + "scale": 0.0001, + "offset": 0, + "unit": "m" + }, + { + "name": "B08 - nir", + "nodata": 0, + "data_type": "uint16", + "bits_per_sample": 15, + "spatial_resolution": 10, + "scale": 0.0001, + "offset": 0, + "unit": "m" + } + ] + }, + "assets": { + "weights": { + "href": "https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_moco/blob/main/resnet50_sentinel2_rgb_moco.pth", + "title": "Pytorch weights checkpoint", + "description": "A Resnet-50 classification model trained on Sentinel-2 RGB imagery with torchgeo.", + "type": "application/octet-stream; application=pytorch", + "roles": [ + "mlm:model", + "mlm:weights" + ] + } + } +} diff --git a/examples/example.json b/examples/item_raster_bands.json similarity index 86% rename from examples/example.json rename to examples/item_raster_bands.json index cb3a41a..1514819 100644 --- a/examples/example.json +++ b/examples/item_raster_bands.json @@ -50,6 +50,7 @@ "mlm:tasks": [ "classification" ], + "mlm:architecture": "ResNet", "mlm:framework": "pytorch", "mlm:framework_version": "2.1.2+cu121", "file:size": 43000000, @@ -59,7 +60,7 @@ "mlm:accelerator": "cuda", "mlm:accelerator_constrained": false, "mlm:accelerator_summary": "Unknown", - "mlm:batch_size_suggestion": null, + "mlm:batch_size_suggestion": 256, "mlm:input": [ { "name": "13 Band Sentinel-2 Batch", @@ -93,47 +94,8 @@ ], "data_type": "float32" }, - "norm_by_channel": true, - "norm_type": "z-score", + "norm_type": null, "resize_type": null, - "parameters": null, - "statistics": { - "minimum": null, - "maximum": null, - "mean": [ - 1354.40546513, - 1118.24399958, - 1042.92983953, - 947.62620298, - 1199.47283961, - 1999.79090914, - 2369.22292565, - 2296.82608323, - 732.08340178, - 12.11327804, - 1819.01027855, - 1118.92391149, - 2594.14080798 - ], - "stddev": [ - 245.71762908, - 333.00778264, - 395.09249139, - 593.75055589, - 566.4170017, - 861.18399006, - 1086.63139075, - 1117.98170791, - 404.91978886, - 4.77584468, - 1002.58768311, - 761.30323499, - 1231.58581042 - ], - "count": null, - "valid_percent": null - }, - "norm_with_clip_values": null, "pre_processing_function": { "format": "python", "expression": "torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn" @@ -146,19 +108,17 @@ "tasks": [ "classification" ], - "result": [ - { - "shape": [ - -1, - 10 - ], - "dim_order": [ - "batch", - "class" - ], - "data_type": "float32" - } - ], + "result": { + "shape": [ + -1, + 10 + ], + "dim_order": [ + "batch", + "class" + ], + "data_type": "float32" + }, "classification_classes": [ { "value": 0, diff --git a/json-schema/schema.json b/json-schema/schema.json index af3af2a..014e0a5 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -24,7 +24,6 @@ "required": [ "mlm:name", "mlm:architecture", - "mlm:framework", "mlm:tasks", "mlm:input", "mlm:output" @@ -208,7 +207,7 @@ }, "mlm:name": { "type": "string", - "pattern": "^[a-zA-Z][a-zA-Z0-9_.-]+[a-zA-Z0-9]$" + "pattern": "^[a-zA-Z][a-zA-Z0-9_.\\-\\s]+[a-zA-Z0-9]$" }, "mlm:architecture": { "type": "string", @@ -605,16 +604,12 @@ "$ref": "https://stac-extensions.github.io/raster/v1.1.0/schema.json#/definitions/bands/items/properties/data_type" }, "AssetModelRole": { - "required": ["assets"], + "required": ["roles"], "properties": { - "assets": { - "additionalProperties": { - "required": ["roles"], - "properties": { - "roles": { - "contains": "mlm:model" - } - } + "roles": { + "contains": { + "type": "string", + "const": "mlm:model" } } } @@ -631,7 +626,7 @@ }, { "$comment": "However, if any band is indicated, a 'bands'-compliant section should describe them.", - "$ref": "#/$defs/AnyBandsRef" + "FIXME_$ref": "#/$defs/AnyBandsRef" } ] }, @@ -658,10 +653,10 @@ "properties": { "bands": { "type": "array", + "minItems": 1, "items": { "type": "string", - "$comment": "This 'minItems' is the purpose of this whole 'if/then' block.", - "minItems": 1 + "$comment": "This 'minItems' is the purpose of this whole 'if/then' block." } } } @@ -756,6 +751,40 @@ ] } ] + }, + "else": { + "$comment": "This is the JSON-object 'properties' definition.", + "properties": { + "$comment": "This is the STAC-Item 'properties' field.", + "properties": { + "required": [ + "mlm:input" + ], + "$comment": "This is the JSON-object 'properties' definition for the STAC Item 'properties' field.", + "properties": { + "$comment": "Required MLM bands listing referring to at least one band name.", + "mlm:input": { + "type": "array", + "items": { + "$comment": "This is the 'Model Input Object' properties.", + "properties": { + "bands": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "array", + "maxItems": 0 + } + ] + } + } + } + } + } + } + } } } } diff --git a/stac_model/input.py b/stac_model/input.py index 107fc5c..680c603 100644 --- a/stac_model/input.py +++ b/stac_model/input.py @@ -57,7 +57,7 @@ class Band(BaseModel): class ModelInput(BaseModel): name: str - bands: List[str] + bands: List[str] # order is critical here (same index as dim shape), allow duplicate if the model needs it somehow input: InputArray norm_by_channel: bool = None norm_type: NormalizeType = None diff --git a/stac_model/runtime.py b/stac_model/runtime.py index 1c0491f..c0a685b 100644 --- a/stac_model/runtime.py +++ b/stac_model/runtime.py @@ -41,11 +41,11 @@ def __str__(self): class Runtime(BaseModel): - framework: str - framework_version: str - file_size: int = Field(alias="file:size") - memory_size: int - batch_size_suggestion: Optional[int] = None + framework: str = Field(default="", exclude_defaults=True, exclude_unset=True) + framework_version: str = Field(default="", exclude_defaults=True, exclude_unset=True) + file_size: int = Field(alias="file:size", default=0, exclude_defaults=True, exclude_unset=True) + memory_size: int = Field(default=0, exclude_defaults=True, exclude_unset=True) + batch_size_suggestion: Optional[int] = Field(default=None, exclude_defaults=True, exclude_unset=True) accelerator: Optional[AcceleratorEnum] = Field(exclude_unset=True, default=None) accelerator_constrained: bool = Field(exclude_unset=True, default=False) diff --git a/tests/test_schema.py b/tests/test_schema.py index 61f717a..b21b0e2 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -8,8 +8,10 @@ @pytest.mark.parametrize( "mlm_example", # value passed to 'mlm_example' fixture [ - "example.json", - "example_eo_bands.json", + "item_basic.json", + "item_raster_bands.json", + "item_eo_bands.json", + "item_multi_io.json", ], indirect=True, )