Skip to content

Commit

Permalink
move back to attrs (#729)
Browse files Browse the repository at this point in the history
* move back to attrs

* update changelog

* edit tests

* more doc
  • Loading branch information
vincentsarago authored Jul 9, 2024
1 parent 494e485 commit 0885f0b
Show file tree
Hide file tree
Showing 13 changed files with 136 additions and 157 deletions.
9 changes: 8 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## [Unreleased] - TBD

## [3.0.0b2] - 2024-07-09

### Changed

* move back to `@attrs` (instead of dataclass) for `APIRequest` (model for GET request) class type [#729](https://github.com/stac-utils/stac-fastapi/pull/729)

## [3.0.0b1] - 2024-07-05

### Added
Expand Down Expand Up @@ -432,7 +438,8 @@

* First PyPi release!

[Unreleased]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0b1..main>
[Unreleased]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0b2..main>
[3.0.0b2]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0b1..3.0.0b2>
[3.0.0b1]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0a4..3.0.0b1>
[3.0.0a4]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0a3..3.0.0a4>
[3.0.0a3]: <https://github.com/stac-utils/stac-fastapi/compare/3.0.0a2..3.0.0a3>
Expand Down
78 changes: 33 additions & 45 deletions docs/src/migrations/v3.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,49 +23,6 @@ In addition to pydantic v2 update, `stac-pydantic` has been updated to better ma

* `PostFieldsExtension.filter_fields` property has been removed.

## `attr` -> `dataclass` for APIRequest models

Models for **GET** requests, defining the path and query parameters, now uses python `dataclass` instead of `attr`.

```python
# before
@attr.s
class CollectionModel(APIRequest):
collections: Optional[str] = attr.ib(default=None, converter=str2list)

# now
@dataclass
class CollectionModel(APIRequest):
collections: Annotated[Optional[str], Query()] = None

def __post_init__(self):
"""convert attributes."""
if self.collections:
self.collections = str2list(self.collections) # type: ignore

```

!!! warning

if you want to extend a class with a `required` attribute (without default), you will have to write all the attributes to avoid having *non-default* attributes defined after *default* attributes (ref: https://github.com/stac-utils/stac-fastapi/pull/714/files#r1651557338)

```python
@dataclass
class A:
value: Annotated[str, Query()]

# THIS WON'T WORK
@dataclass
class B(A):
another_value: Annotated[str, Query(...)]

# DO THIS
@dataclass
class B(A):
another_value: Annotated[str, Query(...)]
value: Annotated[str, Query()]
```

## Middlewares configuration

The `StacApi.middlewares` attribute has been updated to accept a list of `starlette.middleware.Middleware`. This enables dynamic configuration of middlewares (see https://github.com/stac-utils/stac-fastapi/pull/442).
Expand Down Expand Up @@ -113,9 +70,9 @@ stac = StacApi(
)

# now
@dataclass
@attr.s
class CollectionsRequest(APIRequest):
user: str = Query(...)
user: Annotated[str, Query(...)] = attr.ib()

stac = StacApi(
search_get_request_model=getSearchModel,
Expand All @@ -127,6 +84,37 @@ stac = StacApi(
)
```

## APIRequest - GET Request Model

Most of the **GET** endpoints are configured with `stac_fastapi.types.search.APIRequest` base class.

e.g the BaseSearchGetRequest, default for the `GET - /search` endpoint:

```python
@attr.s
class BaseSearchGetRequest(APIRequest):
"""Base arguments for GET Request."""

collections: Annotated[Optional[str], Query()] = attr.ib(
default=None, converter=str2list
)
ids: Annotated[Optional[str], Query()] = attr.ib(default=None, converter=str2list)
bbox: Annotated[Optional[BBox], Query()] = attr.ib(default=None, converter=str2bbox)
intersects: Annotated[Optional[str], Query()] = attr.ib(default=None)
datetime: Annotated[Optional[DateTimeType], Query()] = attr.ib(
default=None, converter=str_to_interval
)
limit: Annotated[Optional[int], Query()] = attr.ib(default=10)
```

We use [*python attrs*](https://www.attrs.org/en/stable/) to construct those classes. **Type Hint** for each attribute is important and should be defined using `Annotated[{type}, fastapi.Query()]` form.

```python
@attr.s
class SomeRequest(APIRequest):
user_number: Annotated[Optional[int], Query(alias="user-number")] = attr.ib(default=None)
```

## Filter extension

`default_includes` attribute has been removed from the `ApiSettings` object. If you need `defaults` includes you can overwrite the `FieldExtension` models (see https://github.com/stac-utils/stac-fastapi/pull/706).
Expand Down
37 changes: 16 additions & 21 deletions stac_fastapi/api/stac_fastapi/api/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Api request/response models."""

from dataclasses import dataclass, make_dataclass
from typing import List, Optional, Type, Union

import attr
from fastapi import Path, Query
from pydantic import BaseModel, create_model
from stac_pydantic.shared import BBox
Expand Down Expand Up @@ -43,11 +43,11 @@ def create_request_model(

mixins = mixins or []

models = extension_models + mixins + [base_model]
models = [base_model] + extension_models + mixins

# Handle GET requests
if all([issubclass(m, APIRequest) for m in models]):
return make_dataclass(model_name, [], bases=tuple(models))
return attr.make_class(model_name, attrs={}, bases=tuple(models))

# Handle POST requests
elif all([issubclass(m, BaseModel) for m in models]):
Expand Down Expand Up @@ -86,43 +86,38 @@ def create_post_request_model(
)


@dataclass
@attr.s
class CollectionUri(APIRequest):
"""Get or delete collection."""

collection_id: Annotated[str, Path(description="Collection ID")]
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()


@dataclass
@attr.s
class ItemUri(APIRequest):
"""Get or delete item."""

collection_id: Annotated[str, Path(description="Collection ID")]
item_id: Annotated[str, Path(description="Item ID")]
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
item_id: Annotated[str, Path(description="Item ID")] = attr.ib()


@dataclass
@attr.s
class EmptyRequest(APIRequest):
"""Empty request."""

...


@dataclass
@attr.s
class ItemCollectionUri(APIRequest):
"""Get item collection."""

collection_id: Annotated[str, Path(description="Collection ID")]
limit: Annotated[int, Query()] = 10
bbox: Annotated[Optional[BBox], Query()] = None
datetime: Annotated[Optional[DateTimeType], Query()] = None

def __post_init__(self):
"""convert attributes."""
if self.bbox:
self.bbox = str2bbox(self.bbox) # type: ignore
if self.datetime:
self.datetime = str_to_interval(self.datetime) # type: ignore
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
limit: Annotated[int, Query()] = attr.ib(default=10)
bbox: Annotated[Optional[BBox], Query()] = attr.ib(default=None, converter=str2bbox)
datetime: Annotated[Optional[DateTimeType], Query()] = attr.ib(
default=None, converter=str_to_interval
)


class GeoJSONResponse(JSONResponse):
Expand Down
27 changes: 14 additions & 13 deletions stac_fastapi/api/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional, Union

import attr
import pytest
from fastapi import Path, Query
from fastapi.testclient import TestClient
from pydantic import ValidationError
from stac_pydantic import api
from typing_extensions import Annotated

from stac_fastapi.api import app
from stac_fastapi.api.models import (
Expand Down Expand Up @@ -328,25 +329,25 @@ def item_collection(
def test_request_model(AsyncTestCoreClient):
"""Test if request models are passed correctly."""

@dataclass
@attr.s
class CollectionsRequest(APIRequest):
user: str = Query(...)
user: Annotated[str, Query(...)] = attr.ib()

@dataclass
@attr.s
class CollectionRequest(APIRequest):
collection_id: str = Path(description="Collection ID")
user: str = Query(...)
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
user: Annotated[str, Query(...)] = attr.ib()

@dataclass
@attr.s
class ItemsRequest(APIRequest):
collection_id: str = Path(description="Collection ID")
user: str = Query(...)
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
user: Annotated[str, Query(...)] = attr.ib()

@dataclass
@attr.s
class ItemRequest(APIRequest):
collection_id: str = Path(description="Collection ID")
item_id: str = Path(description="Item ID")
user: str = Query(...)
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
item_id: Annotated[str, Path(description="Item ID")] = attr.ib()
user: Annotated[str, Query(...)] = attr.ib()

test_app = app.StacApi(
settings=ApiSettings(),
Expand Down
26 changes: 17 additions & 9 deletions stac_fastapi/api/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import json

import pytest
from fastapi import Depends, FastAPI
from fastapi import Depends, FastAPI, HTTPException
from fastapi.testclient import TestClient
from pydantic import ValidationError

from stac_fastapi.api.models import create_get_request_model, create_post_request_model
from stac_fastapi.extensions.core.filter.filter import FilterExtension
from stac_fastapi.extensions.core.sort.sort import SortExtension
from stac_fastapi.extensions.core import FieldsExtension, FilterExtension, SortExtension
from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest


def test_create_get_request_model():
extensions = [FilterExtension()]
request_model = create_get_request_model(extensions, BaseSearchGetRequest)
request_model = create_get_request_model(
extensions=[FilterExtension(), FieldsExtension()],
base_model=BaseSearchGetRequest,
)

model = request_model(
collections="test1,test2",
Expand All @@ -35,6 +36,9 @@ def test_create_get_request_model():
assert model.collections == ["test1", "test2"]
assert model.filter_crs == "epsg:4326"

with pytest.raises(HTTPException):
request_model(datetime="yo")

app = FastAPI()

@app.get("/test")
Expand Down Expand Up @@ -62,8 +66,10 @@ def route(model=Depends(request_model)):
[(None, True), ({"test": "test"}, True), ("test==test", False), ([], False)],
)
def test_create_post_request_model(filter, passes):
extensions = [FilterExtension()]
request_model = create_post_request_model(extensions, BaseSearchPostRequest)
request_model = create_post_request_model(
extensions=[FilterExtension(), FieldsExtension()],
base_model=BaseSearchPostRequest,
)

if not passes:
with pytest.raises(ValidationError):
Expand Down Expand Up @@ -100,8 +106,10 @@ def test_create_post_request_model(filter, passes):
],
)
def test_create_post_request_model_nested_fields(sortby, passes):
extensions = [SortExtension()]
request_model = create_post_request_model(extensions, BaseSearchPostRequest)
request_model = create_post_request_model(
extensions=[SortExtension()],
base_model=BaseSearchPostRequest,
)

if not passes:
with pytest.raises(ValidationError):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Request model for the Aggregation extension."""

from dataclasses import dataclass
from typing import List, Optional

import attr
from fastapi import Query
from pydantic import Field
from typing_extensions import Annotated
Expand All @@ -14,17 +14,13 @@
)


@dataclass
@attr.s
class AggregationExtensionGetRequest(BaseSearchGetRequest):
"""Aggregation Extension GET request model."""

aggregations: Annotated[Optional[str], Query()] = None

def __post_init__(self):
"""convert attributes."""
super().__post_init__()
if self.aggregations:
self.aggregations = str2list(self.aggregations) # type: ignore
aggregations: Annotated[Optional[str], Query()] = attr.ib(
default=None, converter=str2list
)


class AggregationExtensionPostRequest(BaseSearchPostRequest):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Request models for the fields extension."""

import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Set

import attr
from fastapi import Query
from pydantic import BaseModel, Field
from typing_extensions import Annotated
Expand Down Expand Up @@ -70,16 +70,11 @@ def filter_fields(self) -> Dict:
}


@dataclass
@attr.s
class FieldsExtensionGetRequest(APIRequest):
"""Additional fields for the GET request."""

fields: Annotated[Optional[str], Query()] = None

def __post_init__(self):
"""convert attributes."""
if self.fields:
self.fields = str2list(self.fields) # type: ignore
fields: Annotated[Optional[str], Query()] = attr.ib(default=None, converter=str2list)


class FieldsExtensionPostRequest(BaseModel):
Expand Down
Loading

0 comments on commit 0885f0b

Please sign in to comment.