Skip to content

Commit

Permalink
replace list[list] with list
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Mar 15, 2024
1 parent d341fdb commit 7ed97e1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
42 changes: 31 additions & 11 deletions piccolo_api/fastapi/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from inspect import Parameter, Signature, isclass

from fastapi import APIRouter, FastAPI, Request, status
from fastapi.params import Body, Query
from fastapi.params import Query
from pydantic import BaseModel as PydanticBaseModel
from pydantic.main import BaseModel

Expand Down Expand Up @@ -122,7 +122,7 @@ def _get_type(type_: t.Type) -> t.Type:

def _is_multidimensional_array(type_: t.Type) -> bool:
"""
Returns ``True`` if ``_type`` is list[list].
Returns ``True`` if ``_type`` is ``list[list]``.
"""
if t.get_origin(type_) is list:
args = t.get_args(type_)
Expand All @@ -131,6 +131,29 @@ def _is_multidimensional_array(type_: t.Type) -> bool:
return False


def _get_array_base_type(type_: t.Type[t.List]) -> t.Type:
"""
Extracts the base type from an array. For example::
>>> _get_array_base_type(t.List[str])
str
>>> _get_array_base_type(t.List(t.List[str]))
str
"""
args = t.get_args(type_)
if args:
if t.get_origin(args[0]) is list:
return _get_array_base_type(args[0])
else:
return args[0]
return type_


foo = _get_array_base_type(t.List[str])


class FastAPIWrapper:
"""
Wraps ``PiccoloCRUD`` so it can easily be integrated into FastAPI.
Expand Down Expand Up @@ -473,20 +496,17 @@ def modify_signature(
assert annotation is not None
type_ = _get_type(annotation)

# Multidimensional arrays can't be used as query params - only
# body params. For now, we'll use a body param, so it doesn't
# crash, but will explore other options in the future (perhaps a
# string query param with a special query syntax for filtering
# multidimensional arrays).
param_class = (
Body if _is_multidimensional_array(type_=type_) else Query
)
# Multidimensional arrays can't be used as query params - e.g. if
# we have a column type of ``Array(Array(Integer()))``.
# For filtering purposes, we only need ``list[int]``.
if _is_multidimensional_array(type_=type_):
type_ = t.List[_get_array_base_type(type_=type_)]

parameters.append(
Parameter(
name=field_name,
kind=Parameter.POSITIONAL_OR_KEYWORD,
default=param_class(
default=Query(
default=None,
description=(f"Filter by the `{field_name}` column."),
),
Expand Down
17 changes: 16 additions & 1 deletion tests/fastapi/test_fastapi_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from starlette.testclient import TestClient

from piccolo_api.crud.endpoints import PiccoloCRUD
from piccolo_api.fastapi.endpoints import FastAPIWrapper, _get_type
from piccolo_api.fastapi.endpoints import (
FastAPIWrapper,
_get_array_base_type,
_get_type,
)


class Movie(Table):
Expand Down Expand Up @@ -282,3 +286,14 @@ def test_new_union_syntax(self):
"""
self.assertIs(_get_type(str | None), str) # type: ignore
self.assertIs(_get_type(None | str), str) # type: ignore


class TestGetArrayBaseType(TestCase):

def test_get_array_base_type(self):
"""
Make sure that `_get_array_base_type` returns the correct base type.
"""
self.assertIs(_get_array_base_type(t.List[str]), str)
self.assertIs(_get_array_base_type(t.List[t.List[str]]), str)
self.assertIs(_get_array_base_type(t.List[t.List[t.List[str]]]), str)

0 comments on commit 7ed97e1

Please sign in to comment.