diff --git a/stac_model/examples.py b/stac_model/examples.py index 9747086..aaeefa5 100644 --- a/stac_model/examples.py +++ b/stac_model/examples.py @@ -1,6 +1,9 @@ import pystac import json import shapely +from dateutil.parser import parse as parse_dt +from pystac import media_type + from stac_model.base import ProcessingExpression from stac_model.input import ModelInput from stac_model.output import ModelOutput, ModelResult @@ -86,20 +89,6 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]: expression="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn" ), # noqa: E501 ) - # runtime = Runtime( - # framework="torch", - # version="2.1.2+cu121", - # asset=Asset(title = "Pytorch weights checkpoint", description="A Resnet-18 classification model trained on normalized Sentinel-2 imagery with Eurosat landcover labels with torchgeo", # noqa: E501 - # type=".pth", roles=["weights"], href="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth" # noqa: E501 - # ), - # source_code=Asset( - # href="https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207" # noqa: E501 - # ), - # accelerator="cuda", - # accelerator_constrained=False, - # hardware_summary="Unknown", - # commit_hash="61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a", - # ) result_array = ModelResult( shape=[-1, 10], dim_order=["batch", "class"], @@ -128,11 +117,38 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]: result=result_array, post_processing_function=None, ) + assets = { + "model": pystac.Asset( + title="Pytorch weights checkpoint", + description=( + "A Resnet-18 classification model trained on normalized Sentinel-2 " + "imagery with Eurosat landcover labels with torchgeo." + ), + href="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth", + media_type="application/octet-stream; application=pytorch", + roles=[ + "mlm:model", + "mlm:weights", + "data" + ] + ), + "source_code": pystac.Asset( + href="https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207", + media_type="text/x-python", + roles=[ + "mlm:model", + "code" + ] + ) + } ml_model_meta = MLModelProperties( name="Resnet-18 Sentinel-2 ALL MOCO", tasks={"classification"}, framework="pytorch", framework_version="2.1.2+cu121", + accelerator="cuda", + accelerator_constrained=False, + accelerator_summary="Unknown", file_size=43000000, memory_size=1, pretrained_source="EuroSat Sentinel-2", @@ -144,7 +160,7 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]: # in docs. start_datetime=datetime.strptime("1900-01-01", "%Y-%m-%d") # Is this a problem that we don't do date validation if we supply as str? start_datetime = "1900-01-01" - end_datetime = None + end_datetime = "9999-01-01" # cannot be None, invalid against STAC Core! bbox = [ -7.882190080512502, 37.13739173208318, @@ -159,13 +175,14 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]: bbox=bbox, datetime=None, properties={ - "start_datetime": start_datetime, - "end_datetime": end_datetime, + "start_datetime": parse_dt(start_datetime).isoformat() + "Z", + "end_datetime": parse_dt(end_datetime).isoformat() + "Z", "description": ( "Sourced from torchgeo python library," "identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO" ), }, + assets=assets, ) item.add_derived_from( "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a" diff --git a/stac_model/runtime.py b/stac_model/runtime.py index c0a685b..bf38313 100644 --- a/stac_model/runtime.py +++ b/stac_model/runtime.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional +from typing import List, Literal, Optional, Union from pydantic import AnyUrl, BaseModel, ConfigDict, FilePath, Field @@ -40,6 +40,19 @@ def __str__(self): return self.value +AcceleratorName = Literal[ + "amd64", + "cuda", + "xla", + "amd-rocm", + "intel-ipex-cpu", + "intel-ipex-gpu", + "macos-arm", +] + +AcceleratorType = Union[AcceleratorName, AcceleratorEnum] + + class Runtime(BaseModel): framework: str = Field(default="", exclude_defaults=True, exclude_unset=True) framework_version: str = Field(default="", exclude_defaults=True, exclude_unset=True) @@ -47,7 +60,7 @@ class Runtime(BaseModel): memory_size: int = Field(default=0, exclude_defaults=True, exclude_unset=True) batch_size_suggestion: Optional[int] = Field(default=None, exclude_defaults=True, exclude_unset=True) - accelerator: Optional[AcceleratorEnum] = Field(exclude_unset=True, default=None) + accelerator: Optional[AcceleratorType] = Field(exclude_unset=True, default=None) accelerator_constrained: bool = Field(exclude_unset=True, default=False) accelerator_summary: str = Field(exclude_unset=True, exclude_defaults=True, default="") accelerator_count: int = Field(minimum=1, exclude_unset=True, exclude_defaults=True, default=-1) diff --git a/tests/test_schema.py b/tests/test_schema.py index b21b0e2..e720cc7 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -31,3 +31,9 @@ def test_model_metadata_to_dict(eurosat_resnet): def test_validate_model_metadata(eurosat_resnet): assert pystac.read_dict(eurosat_resnet.item.to_dict()) + + +def test_validate_model_against_schema(eurosat_resnet, mlm_validator): + mlm_item = pystac.read_dict(eurosat_resnet.item.to_dict()) + validated = pystac.validation.validate(mlm_item, validator=mlm_validator) + assert SCHEMA_URI in validated