From 087e656e9008c661614638c99ed75e7b32b5ecfb Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 2 Sep 2024 15:25:55 -0400 Subject: [PATCH] v0.0.64 --- agixtsdk/__init__.py | 66 +++++++++++++++++++++++++++++--------------- setup.py | 2 +- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/agixtsdk/__init__.py b/agixtsdk/__init__.py index f5997ab..1fed30a 100644 --- a/agixtsdk/__init__.py +++ b/agixtsdk/__init__.py @@ -1678,7 +1678,7 @@ def plan_task( def _generate_detailed_schema(self, model: Type[BaseModel], depth: int = 0) -> str: """ Recursively generates a detailed schema representation of a Pydantic model, - including nested models. + including nested models and complex types. """ fields = model.__annotations__ field_descriptions = [] @@ -1687,29 +1687,55 @@ def _generate_detailed_schema(self, model: Type[BaseModel], depth: int = 0) -> s for field, field_type in fields.items(): description = f"{indent}{field}: " - if get_origin(field_type) == Union: - field_type = get_args(field_type)[0] + origin_type = get_origin(field_type) + if origin_type is None: + origin_type = field_type - if isinstance(field_type, type) and issubclass(field_type, BaseModel): - description += f"Nested Model:\n{self._generate_detailed_schema(field_type, depth + 1)}" - elif get_origin(field_type) == List: + if issubclass(origin_type, BaseModel): + description += f"Nested Model:\n{self._generate_detailed_schema(origin_type, depth + 1)}" + elif origin_type == List: list_type = get_args(field_type)[0] if isinstance(list_type, type) and issubclass(list_type, BaseModel): description += f"List of Nested Model:\n{self._generate_detailed_schema(list_type, depth + 1)}" + elif get_origin(list_type) == Union: + union_types = get_args(list_type) + description += f"List of Union:\n" + for union_type in union_types: + if issubclass(union_type, BaseModel): + description += f"{indent} - Nested Model:\n{self._generate_detailed_schema(union_type, depth + 2)}" + else: + description += f"{indent} - {union_type.__name__}\n" else: - description += f"List[{list_type.__name__}]" - elif get_origin(field_type) == Dict: + description += f"List[{self._get_type_name(list_type)}]" + elif origin_type == Dict: key_type, value_type = get_args(field_type) - description += f"Dict[{key_type.__name__}, {value_type.__name__}]" - elif isinstance(field_type, type) and issubclass(field_type, Enum): - enum_values = ", ".join([f"{e.name} = {e.value}" for e in field_type]) - description += f"{field_type.__name__} (Enum values: {enum_values})" + description += f"Dict[{self._get_type_name(key_type)}, {self._get_type_name(value_type)}]" + elif origin_type == Union: + union_types = get_args(field_type) + description += "Union of:\n" + for union_type in union_types: + if issubclass(union_type, BaseModel): + description += f"{indent} - Nested Model:\n{self._generate_detailed_schema(union_type, depth + 2)}" + else: + description += ( + f"{indent} - {self._get_type_name(union_type)}\n" + ) + elif issubclass(origin_type, Enum): + enum_values = ", ".join([f"{e.name} = {e.value}" for e in origin_type]) + description += f"{origin_type.__name__} (Enum values: {enum_values})" else: - description += f"{field_type.__name__}" + description += self._get_type_name(origin_type) field_descriptions.append(description) + return "\n".join(field_descriptions) + def _get_type_name(self, type_): + """Helper method to get the name of a type, handling some special cases.""" + if hasattr(type_, "__name__"): + return type_.__name__ + return str(type_).replace("typing.", "") + def convert_to_model( self, input_string: str, @@ -1732,12 +1758,12 @@ def convert_to_model( """ input_string = str(input_string) schema = self._generate_detailed_schema(model) - + if "user_input" in kwargs: del kwargs["user_input"] if "schema" in kwargs: del kwargs["schema"] - + response = self.prompt_agent( agent_name=agent_name, prompt_name="Convert to Model", @@ -1747,12 +1773,12 @@ def convert_to_model( **kwargs, }, ) - + if "```json" in response: response = response.split("```json")[1].split("```")[0].strip() elif "```" in response: response = response.split("```")[1].strip() - + try: response = json.loads(response) if response_type == "json": @@ -1766,11 +1792,7 @@ def convert_to_model( f"Error: {e} . Failed to convert the response to the model after {max_failures} attempts. Response: {response}" ) self.failures = 0 - return ( - response - if response - else "Failed to convert the response to the model." - ) + return response if response else "Failed to convert the response to the model." else: self.failures = 1 print( diff --git a/setup.py b/setup.py index e2d3f1f..09fc584 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name="agixtsdk", - version="0.0.63", + version="0.0.64", description="The AGiXT SDK for Python.", long_description=long_description, long_description_content_type="text/markdown",