Skip to content
This repository has been archived by the owner on Oct 2, 2024. It is now read-only.

Commit

Permalink
adjust pydantic eurosat_example with json-schema fields
Browse files Browse the repository at this point in the history
  • Loading branch information
fmigneault committed Apr 5, 2024
1 parent 2d6c70b commit d111678
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 19 deletions.
51 changes: 34 additions & 17 deletions stac_model/examples.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pystac
import json
import shapely
from dateutil.parser import parse as parse_dt
from pystac import media_type

from stac_model.base import ProcessingExpression
from stac_model.input import ModelInput
from stac_model.output import ModelOutput, ModelResult
Expand Down Expand Up @@ -86,20 +89,6 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
expression="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn"
), # noqa: E501
)
# runtime = Runtime(
# framework="torch",
# version="2.1.2+cu121",
# asset=Asset(title = "Pytorch weights checkpoint", description="A Resnet-18 classification model trained on normalized Sentinel-2 imagery with Eurosat landcover labels with torchgeo", # noqa: E501
# type=".pth", roles=["weights"], href="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth" # noqa: E501
# ),
# source_code=Asset(
# href="https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207" # noqa: E501
# ),
# accelerator="cuda",
# accelerator_constrained=False,
# hardware_summary="Unknown",
# commit_hash="61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a",
# )
result_array = ModelResult(
shape=[-1, 10],
dim_order=["batch", "class"],
Expand Down Expand Up @@ -128,11 +117,38 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
result=result_array,
post_processing_function=None,
)
assets = {
"model": pystac.Asset(
title="Pytorch weights checkpoint",
description=(
"A Resnet-18 classification model trained on normalized Sentinel-2 "
"imagery with Eurosat landcover labels with torchgeo."
),
href="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth",
media_type="application/octet-stream; application=pytorch",
roles=[
"mlm:model",
"mlm:weights",
"data"
]
),
"source_code": pystac.Asset(
href="https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207",
media_type="text/x-python",
roles=[
"mlm:model",
"code"
]
)
}
ml_model_meta = MLModelProperties(
name="Resnet-18 Sentinel-2 ALL MOCO",
tasks={"classification"},
framework="pytorch",
framework_version="2.1.2+cu121",
accelerator="cuda",
accelerator_constrained=False,
accelerator_summary="Unknown",
file_size=43000000,
memory_size=1,
pretrained_source="EuroSat Sentinel-2",
Expand All @@ -144,7 +160,7 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
# in docs. start_datetime=datetime.strptime("1900-01-01", "%Y-%m-%d")
# Is this a problem that we don't do date validation if we supply as str?
start_datetime = "1900-01-01"
end_datetime = None
end_datetime = "9999-01-01" # cannot be None, invalid against STAC Core!
bbox = [
-7.882190080512502,
37.13739173208318,
Expand All @@ -159,13 +175,14 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
bbox=bbox,
datetime=None,
properties={
"start_datetime": start_datetime,
"end_datetime": end_datetime,
"start_datetime": parse_dt(start_datetime).isoformat() + "Z",
"end_datetime": parse_dt(end_datetime).isoformat() + "Z",
"description": (
"Sourced from torchgeo python library,"
"identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO"
),
},
assets=assets,
)
item.add_derived_from(
"https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a"
Expand Down
17 changes: 15 additions & 2 deletions stac_model/runtime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Optional
from typing import List, Literal, Optional, Union

from pydantic import AnyUrl, BaseModel, ConfigDict, FilePath, Field

Expand Down Expand Up @@ -40,14 +40,27 @@ def __str__(self):
return self.value


# Literal alias enumerating the accelerator names that may be supplied as
# plain strings.
AcceleratorName = Literal[
"amd64",
"cuda",
"xla",
"amd-rocm",
"intel-ipex-cpu",
"intel-ipex-gpu",
"macos-arm",
]

# An accelerator value may be either one of the literal strings above or an
# AcceleratorEnum member (AcceleratorEnum is declared earlier in this module,
# outside this excerpt). Used as the type of the `accelerator` field on Runtime.
AcceleratorType = Union[AcceleratorName, AcceleratorEnum]


class Runtime(BaseModel):
framework: str = Field(default="", exclude_defaults=True, exclude_unset=True)
framework_version: str = Field(default="", exclude_defaults=True, exclude_unset=True)
file_size: int = Field(alias="file:size", default=0, exclude_defaults=True, exclude_unset=True)
memory_size: int = Field(default=0, exclude_defaults=True, exclude_unset=True)
batch_size_suggestion: Optional[int] = Field(default=None, exclude_defaults=True, exclude_unset=True)

accelerator: Optional[AcceleratorEnum] = Field(exclude_unset=True, default=None)
accelerator: Optional[AcceleratorType] = Field(exclude_unset=True, default=None)
accelerator_constrained: bool = Field(exclude_unset=True, default=False)
accelerator_summary: str = Field(exclude_unset=True, exclude_defaults=True, default="")
accelerator_count: int = Field(minimum=1, exclude_unset=True, exclude_defaults=True, default=-1)
6 changes: 6 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ def test_model_metadata_to_dict(eurosat_resnet):

def test_validate_model_metadata(eurosat_resnet):
assert pystac.read_dict(eurosat_resnet.item.to_dict())


def test_validate_model_against_schema(eurosat_resnet, mlm_validator):
    """Validate the EuroSAT example item with the ``mlm_validator`` fixture.

    The item held by the ``eurosat_resnet`` fixture is serialized to a dict,
    read back through ``pystac.read_dict`` and validated with the supplied
    validator. ``SCHEMA_URI`` must be present in the validation result.
    """
    item_as_dict = eurosat_resnet.item.to_dict()
    reloaded_item = pystac.read_dict(item_as_dict)
    validated_schemas = pystac.validation.validate(
        reloaded_item,
        validator=mlm_validator,
    )
    assert SCHEMA_URI in validated_schemas

0 comments on commit d111678

Please sign in to comment.