diff --git a/seed/fastapi/trace/v_2/problem/types/create_problem_request_v_2.py b/seed/fastapi/trace/v_2/problem/types/create_problem_request_v_2.py index 21725bb52..7c307b1d4 100644 --- a/seed/fastapi/trace/v_2/problem/types/create_problem_request_v_2.py +++ b/seed/fastapi/trace/v_2/problem/types/create_problem_request_v_2.py @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel): custom_files: CustomFiles = pydantic.Field(alias="customFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") testcases: typing.List[TestCaseV2] - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") is_public: bool = pydantic.Field(alias="isPublic") def json(self, **kwargs: typing.Any) -> str: diff --git a/seed/fastapi/trace/v_2/problem/types/problem_info_v_2.py b/seed/fastapi/trace/v_2/problem/types/problem_info_v_2.py index 162d2c1f8..e56985c4c 100644 --- a/seed/fastapi/trace/v_2/problem/types/problem_info_v_2.py +++ b/seed/fastapi/trace/v_2/problem/types/problem_info_v_2.py @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel): problem_description: ProblemDescription = pydantic.Field(alias="problemDescription") problem_name: str = pydantic.Field(alias="problemName") problem_version: int = pydantic.Field(alias="problemVersion") - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") custom_files: CustomFiles = pydantic.Field(alias="customFiles") generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") diff --git a/seed/fastapi/trace/v_2/v_3/problem/types/create_problem_request_v_2.py b/seed/fastapi/trace/v_2/v_3/problem/types/create_problem_request_v_2.py index ede8824fd..8a1944739 100644 --- a/seed/fastapi/trace/v_2/v_3/problem/types/create_problem_request_v_2.py +++ b/seed/fastapi/trace/v_2/v_3/problem/types/create_problem_request_v_2.py @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel): custom_files: CustomFiles = pydantic.Field(alias="customFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") testcases: typing.List[TestCaseV2] - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") is_public: bool = pydantic.Field(alias="isPublic") def json(self, **kwargs: typing.Any) -> str: diff --git a/seed/fastapi/trace/v_2/v_3/problem/types/problem_info_v_2.py b/seed/fastapi/trace/v_2/v_3/problem/types/problem_info_v_2.py index 5f32a3687..778367e20 100644 --- a/seed/fastapi/trace/v_2/v_3/problem/types/problem_info_v_2.py +++ b/seed/fastapi/trace/v_2/v_3/problem/types/problem_info_v_2.py @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel): problem_description: ProblemDescription = pydantic.Field(alias="problemDescription") problem_name: str = pydantic.Field(alias="problemName") problem_version: int = pydantic.Field(alias="problemVersion") - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") custom_files: CustomFiles = pydantic.Field(alias="customFiles") generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") diff --git a/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/problem/create_problem_request_v_2.py b/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/problem/create_problem_request_v_2.py index 8d8909815..a6f7a341e 100644 --- a/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/problem/create_problem_request_v_2.py +++ b/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/problem/create_problem_request_v_2.py @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel): custom_files: CustomFiles = pydantic.Field(alias="customFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") testcases: typing.List[TestCaseV2] - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") is_public: bool = pydantic.Field(alias="isPublic") def json(self, **kwargs: typing.Any) -> str: diff --git a/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/problem/problem_info_v_2.py b/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/problem/problem_info_v_2.py index 5ea740531..6a203a362 100644 --- a/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/problem/problem_info_v_2.py +++ b/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/problem/problem_info_v_2.py @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel): problem_description: ProblemDescription = pydantic.Field(alias="problemDescription") problem_name: str = pydantic.Field(alias="problemName") problem_version: int = pydantic.Field(alias="problemVersion") - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") custom_files: CustomFiles = pydantic.Field(alias="customFiles") generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") diff --git a/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/v_3/resources/problem/create_problem_request_v_2.py b/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/v_3/resources/problem/create_problem_request_v_2.py index 25fb85949..c062a7b9b 100644 --- a/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/v_3/resources/problem/create_problem_request_v_2.py +++ b/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/v_3/resources/problem/create_problem_request_v_2.py @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel): custom_files: CustomFiles = pydantic.Field(alias="customFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") testcases: typing.List[TestCaseV2] - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") is_public: bool = pydantic.Field(alias="isPublic") def json(self, **kwargs: typing.Any) -> str: diff --git a/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/v_3/resources/problem/problem_info_v_2.py b/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/v_3/resources/problem/problem_info_v_2.py index afb3fefea..6cddb8b64 100644 --- a/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/v_3/resources/problem/problem_info_v_2.py +++ b/seed/pydantic/trace/src/seed/trace/resources/v_2/resources/v_3/resources/problem/problem_info_v_2.py @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel): problem_description: ProblemDescription = pydantic.Field(alias="problemDescription") problem_name: str = pydantic.Field(alias="problemName") problem_version: int = pydantic.Field(alias="problemVersion") - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") custom_files: CustomFiles = pydantic.Field(alias="customFiles") generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") diff --git a/seed/sdk/trace/src/seed/v_2/problem/create_problem_request_v_2.py b/seed/sdk/trace/src/seed/v_2/problem/create_problem_request_v_2.py index bd555a3ca..5a66228cb 100644 --- a/seed/sdk/trace/src/seed/v_2/problem/create_problem_request_v_2.py +++ b/seed/sdk/trace/src/seed/v_2/problem/create_problem_request_v_2.py @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel): custom_files: CustomFiles = pydantic.Field(alias="customFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") testcases: typing.List[TestCaseV2] - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") is_public: bool = pydantic.Field(alias="isPublic") def json(self, **kwargs: typing.Any) -> str: diff --git a/seed/sdk/trace/src/seed/v_2/problem/problem_info_v_2.py b/seed/sdk/trace/src/seed/v_2/problem/problem_info_v_2.py index aa1204d34..16732347c 100644 --- a/seed/sdk/trace/src/seed/v_2/problem/problem_info_v_2.py +++ b/seed/sdk/trace/src/seed/v_2/problem/problem_info_v_2.py @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel): problem_description: ProblemDescription = pydantic.Field(alias="problemDescription") problem_name: str = pydantic.Field(alias="problemName") problem_version: int = pydantic.Field(alias="problemVersion") - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") custom_files: CustomFiles = pydantic.Field(alias="customFiles") generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") diff --git a/seed/sdk/trace/src/seed/v_2/v_3/problem/create_problem_request_v_2.py b/seed/sdk/trace/src/seed/v_2/v_3/problem/create_problem_request_v_2.py index 50d9516f3..c91ba470f 100644 --- a/seed/sdk/trace/src/seed/v_2/v_3/problem/create_problem_request_v_2.py +++ b/seed/sdk/trace/src/seed/v_2/v_3/problem/create_problem_request_v_2.py @@ -22,7 +22,7 @@ class CreateProblemRequestV2(pydantic.BaseModel): custom_files: CustomFiles = pydantic.Field(alias="customFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") testcases: typing.List[TestCaseV2] - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") is_public: bool = pydantic.Field(alias="isPublic") def json(self, **kwargs: typing.Any) -> str: diff --git a/seed/sdk/trace/src/seed/v_2/v_3/problem/problem_info_v_2.py b/seed/sdk/trace/src/seed/v_2/v_3/problem/problem_info_v_2.py index 29ea183a1..50cbbbfc8 100644 --- a/seed/sdk/trace/src/seed/v_2/v_3/problem/problem_info_v_2.py +++ b/seed/sdk/trace/src/seed/v_2/v_3/problem/problem_info_v_2.py @@ -23,7 +23,7 @@ class ProblemInfoV2(pydantic.BaseModel): problem_description: ProblemDescription = pydantic.Field(alias="problemDescription") problem_name: str = pydantic.Field(alias="problemName") problem_version: int = pydantic.Field(alias="problemVersion") - supported_languages: typing.List[Language] = pydantic.Field(alias="supportedLanguages") + supported_languages: typing.Set[Language] = pydantic.Field(alias="supportedLanguages") custom_files: CustomFiles = pydantic.Field(alias="customFiles") generated_files: GeneratedFiles = pydantic.Field(alias="generatedFiles") custom_test_case_templates: typing.List[TestCaseTemplate] = pydantic.Field(alias="customTestCaseTemplates") diff --git a/src/fern_python/generators/context/pydantic_generator_context_impl.py b/src/fern_python/generators/context/pydantic_generator_context_impl.py index e743aa100..7849e6f67 100644 --- a/src/fern_python/generators/context/pydantic_generator_context_impl.py +++ b/src/fern_python/generators/context/pydantic_generator_context_impl.py @@ -20,7 +20,7 @@ def __init__( ): super().__init__(ir=ir, generator_config=generator_config) self._type_reference_to_type_hint_converter = TypeReferenceToTypeHintConverter( - type_declaration_referencer=type_declaration_referencer, + type_declaration_referencer=type_declaration_referencer, context=self ) self._type_declaration_referencer = type_declaration_referencer self._project_module_path = project_module_path diff --git a/src/fern_python/generators/context/type_reference_to_type_hint_converter.py b/src/fern_python/generators/context/type_reference_to_type_hint_converter.py index 1a33d0cd9..892cceb2e 100644 --- a/src/fern_python/generators/context/type_reference_to_type_hint_converter.py +++ b/src/fern_python/generators/context/type_reference_to_type_hint_converter.py @@ -5,9 +5,16 @@ from fern_python.codegen import AST from fern_python.declaration_referencer import AbstractDeclarationReferencer +from .pydantic_generator_context import PydanticGeneratorContext + class TypeReferenceToTypeHintConverter: - def __init__(self, type_declaration_referencer: AbstractDeclarationReferencer[ir_types.DeclaredTypeName]): + def __init__( + self, + type_declaration_referencer: AbstractDeclarationReferencer[ir_types.DeclaredTypeName], + context: PydanticGeneratorContext, + ): + self._context = context self._type_declaration_referencer = type_declaration_referencer def get_type_hint_for_type_reference( @@ -28,6 +35,28 @@ def get_type_hint_for_type_reference( unknown=AST.TypeHint.any, ) + def _get_set_type_hint_for_named( + self, + name: ir_types.DeclaredTypeName, + must_import_after_current_declaration: Optional[Callable[[ir_types.DeclaredTypeName], bool]], + ) -> AST.TypeHint: + is_primative = self._context.get_declaration_for_type_id(name.type_id).shape.visit( + alias=lambda alias_td: alias_td.resolved_type.visit( + container=lambda c: False, named=lambda n: False, primitive=lambda p: True, unknown=lambda: False + ), + enum=lambda enum_td: True, + object=lambda object_td: False, + union=lambda union_td: False, + undiscriminated_union=lambda union_td: False, + ) + inner_hint = self._get_type_hint_for_named( + type_name=name, + must_import_after_current_declaration=must_import_after_current_declaration, + ) + if is_primative: + return AST.TypeHint.set(inner_hint) + return AST.TypeHint.list(inner_hint) + def _get_type_hint_for_container( self, container: ir_types.ContainerType, @@ -58,16 +87,12 @@ def _get_type_hint_for_container( must_import_after_current_declaration=must_import_after_current_declaration, ) ), - named=lambda type_reference: AST.TypeHint.list( - self._get_type_hint_for_named( - type_name=type_reference, - must_import_after_current_declaration=must_import_after_current_declaration, - ) + named=lambda type_reference: self._get_set_type_hint_for_named( + type_reference, + must_import_after_current_declaration=must_import_after_current_declaration, ), primitive=lambda type_reference: AST.TypeHint.set( - self._get_type_hint_for_primitive( - primitive=type_reference, - ) + self._get_type_hint_for_primitive(primitive=type_reference) ), unknown=lambda: AST.TypeHint.list(AST.TypeHint.any()), ),