Skip to content
This repository has been archived by the owner on Feb 5, 2024. It is now read-only.

Commit

Permalink
[fix] also ensure that aliased primatives can become sets as well (#470)
Browse files Browse the repository at this point in the history
* Allow for making sets

* fix types in visitor

* seed tests

* run seed again but bumped v

* other seed tests

* add fastapi tests too...

* factor in the use of primative aliases

* fix primative aliases
  • Loading branch information
armandobelardo authored Dec 19, 2023
1 parent 0bcdcef commit ea09c46
Show file tree
Hide file tree
Showing 14 changed files with 47 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel):
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
testcases: typing.List[TestCaseV2]
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
is_public: bool = pydantic.Field(alias="isPublic")

def json(self, **kwargs: typing.Any) -> str:
Expand Down
2 changes: 1 addition & 1 deletion seed/fastapi/trace/v_2/problem/types/problem_info_v_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel):
problem_description: ProblemDescription = pydantic.Field(alias="problemDescription")
problem_name: str = pydantic.Field(alias="problemName")
problem_version: int = pydantic.Field(alias="problemVersion")
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel):
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
testcases: typing.List[TestCaseV2]
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
is_public: bool = pydantic.Field(alias="isPublic")

def json(self, **kwargs: typing.Any) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel):
problem_description: ProblemDescription = pydantic.Field(alias="problemDescription")
problem_name: str = pydantic.Field(alias="problemName")
problem_version: int = pydantic.Field(alias="problemVersion")
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel):
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
testcases: typing.List[TestCaseV2]
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
is_public: bool = pydantic.Field(alias="isPublic")

def json(self, **kwargs: typing.Any) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel):
problem_description: ProblemDescription = pydantic.Field(alias="problemDescription")
problem_name: str = pydantic.Field(alias="problemName")
problem_version: int = pydantic.Field(alias="problemVersion")
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel):
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
testcases: typing.List[TestCaseV2]
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
is_public: bool = pydantic.Field(alias="isPublic")

def json(self, **kwargs: typing.Any) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel):
problem_description: ProblemDescription = pydantic.Field(alias="problemDescription")
problem_name: str = pydantic.Field(alias="problemName")
problem_version: int = pydantic.Field(alias="problemVersion")
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel):
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
testcases: typing.List[TestCaseV2]
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
is_public: bool = pydantic.Field(alias="isPublic")

def json(self, **kwargs: typing.Any) -> str:
Expand Down
2 changes: 1 addition & 1 deletion seed/sdk/trace/src/seed/v_2/problem/problem_info_v_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel):
problem_description: ProblemDescription = pydantic.Field(alias="problemDescription")
problem_name: str = pydantic.Field(alias="problemName")
problem_version: int = pydantic.Field(alias="problemVersion")
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel):
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
testcases: typing.List[TestCaseV2]
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
is_public: bool = pydantic.Field(alias="isPublic")

def json(self, **kwargs: typing.Any) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel):
problem_description: ProblemDescription = pydantic.Field(alias="problemDescription")
problem_name: str = pydantic.Field(alias="problemName")
problem_version: int = pydantic.Field(alias="problemVersion")
supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages")
supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages")
custom_files: CustomFiles = pydantic.Field(alias="customFiles")
generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles")
custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
):
super().__init__(ir=ir, generator_config=generator_config)
self._type_reference_to_type_hint_converter = TypeReferenceToTypeHintConverter(
type_declaration_referencer=type_declaration_referencer,
type_declaration_referencer=type_declaration_referencer, context=self
)
self._type_declaration_referencer = type_declaration_referencer
self._project_module_path = project_module_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@
from fern_python.codegen import AST
from fern_python.declaration_referencer import AbstractDeclarationReferencer

from .pydantic_generator_context import PydanticGeneratorContext


class TypeReferenceToTypeHintConverter:
def __init__(self, type_declaration_referencer: AbstractDeclarationReferencer[ir_types.DeclaredTypeName]):
def __init__(
self,
type_declaration_referencer: AbstractDeclarationReferencer[ir_types.DeclaredTypeName],
context: PydanticGeneratorContext,
):
self._context = context
self._type_declaration_referencer = type_declaration_referencer

def get_type_hint_for_type_reference(
Expand All @@ -28,6 +35,28 @@ def get_type_hint_for_type_reference(
unknown=AST.TypeHint.any,
)

def _get_set_type_hint_for_named(
self,
name: ir_types.DeclaredTypeName,
must_import_after_current_declaration: Optional[Callable[[ir_types.DeclaredTypeName], bool]],
) -> AST.TypeHint:
is_primative = self._context.get_declaration_for_type_id(name.type_id).shape.visit(
alias=lambda alias_td: alias_td.resolved_type.visit(
container=lambda c: False, named=lambda n: False, primitive=lambda p: True, unknown=lambda: False
),
enum=lambda enum_td: True,
object=lambda object_td: False,
union=lambda union_td: False,
undiscriminated_union=lambda union_td: False,
)
inner_hint = self._get_type_hint_for_named(
type_name=name,
must_import_after_current_declaration=must_import_after_current_declaration,
)
if is_primative:
return AST.TypeHint.set(inner_hint)
return AST.TypeHint.list(inner_hint)

def _get_type_hint_for_container(
self,
container: ir_types.ContainerType,
Expand Down Expand Up @@ -58,16 +87,12 @@ def _get_type_hint_for_container(
must_import_after_current_declaration=must_import_after_current_declaration,
)
),
named=lambda type_reference: AST.TypeHint.list(
self._get_type_hint_for_named(
type_name=type_reference,
must_import_after_current_declaration=must_import_after_current_declaration,
)
named=lambda type_reference: self._get_set_type_hint_for_named(
type_reference,
must_import_after_current_declaration=must_import_after_current_declaration,
),
primitive=lambda type_reference: AST.TypeHint.set(
self._get_type_hint_for_primitive(
primitive=type_reference,
)
self._get_type_hint_for_primitive(primitive=type_reference)
),
unknown=lambda: AST.TypeHint.list(AST.TypeHint.any()),
),
Expand Down

0 comments on commit ea09c46

Please sign in to comment.