Skip to content
This repository has been archived by the owner on Oct 2, 2024. It is now read-only.

Commit

Permalink
fix pydantic drop unset fields as intended
Browse files Browse the repository at this point in the history
  • Loading branch information
fmigneault committed Apr 5, 2024
1 parent d111678 commit 4d57e41
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 143 deletions.
6 changes: 5 additions & 1 deletion json-schema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,11 @@
},
"InputStatistics": {
"$comment": "MLM statistics for the specific input relevant for normalization for ML features.",
"$ref": "https://stac-extensions.github.io/raster/v1.1.0/schema.json#/definitions/bands/items/properties/statistics"
"type": "array",
"minItems": 1,
"items": {
"$ref": "https://stac-extensions.github.io/raster/v1.1.0/schema.json#/definitions/bands/items/properties/statistics"
}
},
"ProcessingExpression": {
"oneOf": [
Expand Down
39 changes: 38 additions & 1 deletion stac_model/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,44 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Union, TypeAlias

from pydantic import BaseModel
from pydantic import BaseModel, model_serializer


@dataclass(frozen=True)
class _OmitIfNone:
    """
    Marker type placed in ``Annotated[...]`` metadata to flag a field as omittable.

    The class carries no data; ``MLMBaseModel`` only checks for its presence with ``isinstance``.
    ``frozen=True`` keeps the default dataclass equality (all markers compare equal) and also
    generates ``__hash__``, so the marker, and any ``Annotated`` alias holding it, stays hashable.
    """


# Shared marker instance to use in annotations: ``Annotated[<type>, OmitIfNone]``.
OmitIfNone = _OmitIfNone()


class MLMBaseModel(BaseModel):
    """
    Allows wrapping any field with an annotation to drop it entirely if unset.
    ```
    field: Annotated[Optional[<desiredType>], OmitIfNone] = None
    # or
    field: Annotated[<desiredType>, OmitIfNone] = None
    # or
    field: Annotated[<desiredType>, OmitIfNone] = Field(default=None)
    ```
    It is important to use `MLMBaseModel`, otherwise the serializer will not be called and applied.
    """

    @model_serializer
    def model_serialize(self):
        """
        Serialize the model to a dict, dropping ``OmitIfNone`` fields whose value is ``None``.

        Keys are the field alias when one is defined, otherwise the attribute name.
        Fields without the ``OmitIfNone`` marker are always emitted, even when ``None``.
        """
        # Read field definitions from the class through the pydantic v2 accessor
        # instead of the deprecated v1 ``__fields__`` attribute.
        fields = type(self).model_fields
        omit_if_none_fields = {
            name
            for name, info in fields.items()
            if any(isinstance(meta, _OmitIfNone) for meta in info.metadata)
        }

        values = {}
        for name, value in self:
            if name in omit_if_none_fields and value is None:
                continue
            # Iterating a model can also yield extra (undeclared) attributes;
            # those have no field definition, so fall back to the plain name.
            info = fields.get(name)
            alias = info.alias if info is not None else None
            values[alias or name] = value  # use the alias if specified
        return values


DataType: TypeAlias = Literal[
Expand Down
130 changes: 69 additions & 61 deletions stac_model/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,18 @@
import json
import shapely
from dateutil.parser import parse as parse_dt
from pystac import media_type
from typing import cast

from pystac.extensions.file import FileExtension

from stac_model.base import ProcessingExpression
from stac_model.input import ModelInput
from stac_model.output import ModelOutput, ModelResult
from stac_model.schema import (
Asset,
InputArray,
MLMClassification,
MLModelExtension,
MLModelProperties,
Runtime,
Statistics,
)
from stac_model.input import ModelInput, InputStructure, MLMStatistic
from stac_model.output import ModelOutput, ModelResult, MLMClassification
from stac_model.schema import MLModelExtension, MLModelProperties


def eurosat_resnet() -> MLModelExtension[pystac.Item]:
input_array = InputArray(
input_array = InputStructure(
shape=[-1, 13, 64, 64],
dim_order=[
"batch",
Expand All @@ -29,53 +23,56 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
],
data_type="float32",
)
band_names = [
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
band_names = []
# band_names = [
# "B01",
# "B02",
# "B03",
# "B04",
# "B05",
# "B06",
# "B07",
# "B08",
# "B8A",
# "B09",
# "B10",
# "B11",
# "B12",
# ]
stats_mean = [
1354.40546513,
1118.24399958,
1042.92983953,
947.62620298,
1199.47283961,
1999.79090914,
2369.22292565,
2296.82608323,
732.08340178,
12.11327804,
1819.01027855,
1118.92391149,
2594.14080798,
]
stats_stddev = [
245.71762908,
333.00778264,
395.09249139,
593.75055589,
566.4170017,
861.18399006,
1086.63139075,
1117.98170791,
404.91978886,
4.77584468,
1002.58768311,
761.30323499,
1231.58581042,
]
stats = [
MLMStatistic(mean=mean, stddev=stddev)
for mean, stddev in zip(stats_mean, stats_stddev)
]
stats = Statistics(
mean=[
1354.40546513,
1118.24399958,
1042.92983953,
947.62620298,
1199.47283961,
1999.79090914,
2369.22292565,
2296.82608323,
732.08340178,
12.11327804,
1819.01027855,
1118.92391149,
2594.14080798,
],
stddev=[
245.71762908,
333.00778264,
395.09249139,
593.75055589,
566.4170017,
861.18399006,
1086.63139075,
1117.98170791,
404.91978886,
4.77584468,
1002.58768311,
761.30323499,
1231.58581042,
],
)
input = ModelInput(
name="13 Band Sentinel-2 Batch",
bands=band_names,
Expand Down Expand Up @@ -141,16 +138,20 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
]
)
}

ml_model_size = 43000000
ml_model_meta = MLModelProperties(
name="Resnet-18 Sentinel-2 ALL MOCO",
architecture="ResNet-18",
tasks={"classification"},
framework="pytorch",
framework_version="2.1.2+cu121",
accelerator="cuda",
accelerator_constrained=False,
accelerator_summary="Unknown",
file_size=43000000,
file_size=ml_model_size,
memory_size=1,
pretrained=True,
pretrained_source="EuroSat Sentinel-2",
total_parameters=11_700_000,
input=[input],
Expand Down Expand Up @@ -187,6 +188,13 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]:
item.add_derived_from(
"https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a"
)

model_asset = cast(
FileExtension[pystac.Asset],
pystac.extensions.file.FileExtension.ext(assets["model"], add_if_missing=True)
)
model_asset.apply(size=ml_model_size)

item_mlm = MLModelExtension.ext(item, add_if_missing=True)
item_mlm.apply(ml_model_meta.model_dump(by_alias=True))
item_mlm.apply(ml_model_meta.model_dump(by_alias=True, exclude_unset=True, exclude_defaults=True))
return item_mlm
51 changes: 23 additions & 28 deletions stac_model/input.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,26 @@
from typing import Any, List, Literal, Optional, Set, TypeAlias, Union
from typing import Any, Annotated, List, Literal, Optional, Set, TypeAlias, Union

from pydantic import BaseModel, Field
from pystac.extensions.raster import Statistics
from pydantic import ConfigDict, Field, model_serializer

from stac_model.base import DataType, ProcessingExpression
from stac_model.base import DataType, MLMBaseModel, ProcessingExpression, OmitIfNone


class InputArray(BaseModel):
shape: List[Union[int, float]] = Field(..., min_items=1)
dim_order: List[str] = Field(..., min_items=1)
data_type: DataType
Number: TypeAlias = Union[int, float]


class Statistics(BaseModel):
minimum: Optional[List[Union[float, int]]] = None
maximum: Optional[List[Union[float, int]]] = None
mean: Optional[List[float]] = None
stddev: Optional[List[float]] = None
count: Optional[List[int]] = None
valid_percent: Optional[List[float]] = None
class InputStructure(MLMBaseModel):
    """Shape, dimension ordering and data type of a model input array."""

    # ``min_length`` is the pydantic v2 name of this constraint; the v1 keyword
    # ``min_items`` is deprecated there and emits a deprecation warning.
    shape: List[Union[int, float]] = Field(min_length=1)
    dim_order: List[str] = Field(min_length=1)
    data_type: DataType


class Band(BaseModel):
name: str
description: Optional[str] = None
nodata: Union[float, int, str]
data_type: str
unit: Optional[str] = None
class MLMStatistic(MLMBaseModel):  # FIXME: add 'Statistics' dep from raster extension (cases required to be triggered)
    """
    Statistics entry attached to a model input.

    Every field is optional and marked with ``OmitIfNone``, so a value left as ``None``
    is dropped from the serialized output instead of being emitted as ``null``.
    """
    minimum: Annotated[Optional[Number], OmitIfNone] = None
    maximum: Annotated[Optional[Number], OmitIfNone] = None
    mean: Annotated[Optional[Number], OmitIfNone] = None
    stddev: Annotated[Optional[Number], OmitIfNone] = None
    count: Annotated[Optional[int], OmitIfNone] = None
    valid_percent: Annotated[Optional[Number], OmitIfNone] = None


NormalizeType: TypeAlias = Optional[Literal[
Expand Down Expand Up @@ -55,13 +50,13 @@ class Band(BaseModel):
]]


class ModelInput(BaseModel):
class ModelInput(MLMBaseModel):
name: str
bands: List[str] # order is critical here (same index as dim shape), allow duplicate if the model needs it somehow
input: InputArray
norm_by_channel: bool = None
norm_type: NormalizeType = None
norm_clip: Optional[List[Union[float, int]]] = None
resize_type: ResizeType = None
statistics: Optional[Union[Statistics, List[Statistics]]] = None
input: InputStructure
norm_by_channel: Annotated[bool, OmitIfNone] = None
norm_type: Annotated[NormalizeType, OmitIfNone] = None
norm_clip: Annotated[List[Union[float, int]], OmitIfNone] = None
resize_type: Annotated[ResizeType, OmitIfNone] = None
statistics: Annotated[List[MLMStatistic], OmitIfNone] = None
pre_processing_function: Optional[ProcessingExpression] = None
16 changes: 7 additions & 9 deletions stac_model/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing_extensions import NotRequired, TypedDict

from pystac.extensions.classification import Classification
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, PlainSerializer, model_serializer
from pydantic import AliasChoices, ConfigDict, Field, PlainSerializer, model_serializer

from stac_model.base import DataType, ModelTask, ProcessingExpression
from stac_model.base import DataType, MLMBaseModel, ModelTask, ProcessingExpression, OmitIfNone


class ModelResult(BaseModel):
class ModelResult(MLMBaseModel):
shape: List[Union[int, float]] = Field(..., min_items=1)
dim_order: List[str] = Field(..., min_items=1)
data_type: DataType
Expand All @@ -31,7 +31,7 @@ class ModelResult(BaseModel):
# ]


class MLMClassification(BaseModel, Classification):
class MLMClassification(MLMBaseModel, Classification):
@model_serializer()
def model_dump(self, *_, **__) -> Dict[str, Any]:
return self.to_dict()
Expand Down Expand Up @@ -60,7 +60,7 @@ def __setattr__(self, key: str, value: Any) -> None:
if key == "properties":
Classification.__setattr__(self, key, value)
else:
BaseModel.__setattr__(self, key, value)
MLMBaseModel.__setattr__(self, key, value)

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand All @@ -73,7 +73,7 @@ def __setattr__(self, key: str, value: Any) -> None:
# nodata: Optional[bool] = False


class ModelOutput(BaseModel):
class ModelOutput(MLMBaseModel):
name: str
tasks: Set[ModelTask]
result: ModelResult
Expand All @@ -83,11 +83,9 @@ class ModelOutput(BaseModel):
# it is more important to keep the order in this case,
# which we would lose with 'Set'.
# We also get some unhashable errors with 'Set', although 'MLMClassification' implements '__hash__'.
classes: List[MLMClassification] = Field(
classes: Annotated[List[MLMClassification], OmitIfNone] = Field(
alias="classification:classes",
validation_alias=AliasChoices("classification:classes", "classification_classes"),
exclude_unset=True,
exclude_defaults=True
)
post_processing_function: Optional[ProcessingExpression] = None

Expand Down
Loading

0 comments on commit 4d57e41

Please sign in to comment.