diff --git a/llama-index-core/llama_index/core/indices/property_graph/transformations/schema_llm.py b/llama-index-core/llama_index/core/indices/property_graph/transformations/schema_llm.py index c5c9ba2eead9d..f9a34c772985b 100644 --- a/llama-index-core/llama_index/core/indices/property_graph/transformations/schema_llm.py +++ b/llama-index-core/llama_index/core/indices/property_graph/transformations/schema_llm.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union try: from typing import TypeAlias @@ -47,62 +47,37 @@ "HAS_ALIAS", ] -# Which entities can be connected to which relations -DEFAULT_VALIDATION_SCHEMA: Dict[str, Any] = { - "PRODUCT": ( - "USED_BY", - "USED_FOR", - "LOCATED_IN", - "PART_OF", - "WORKED_ON", - "HAS", - "IS_A", - ), - "MARKET": ("LOCATED_IN", "PART_OF", "WORKED_ON", "HAS", "IS_A"), - "TECHNOLOGY": ( - "USED_BY", - "USED_FOR", - "LOCATED_IN", - "PART_OF", - "WORKED_ON", - "HAS", - "IS_A", - ), - "EVENT": ("LOCATED_IN", "PART_OF", "WORKED_ON", "HAS", "IS_A"), - "CONCEPT": ("USE_BY", "USED_FOR", "PART_OF", "WORKED_ON", "HAS", "IS_A"), - "ORGANIZATION": ("LOCATED_IN", "PART_OF", "HAS", "IS_A"), - "PERSON": ( - "BORN_IN", - "DIED_IN", - "LOCATED_IN", - "PART_OF", - "WORKED_ON", - "HAS", - "IS_A", - ), - "LOCATION": ( - "LOCATED_IN", - "PART_OF", - "HAS", - "IS_A", - "DIED_IN", - "BORN_IN", - "USED_BY", - "USED_FOR", - ), - "TIME": ("BORN_IN", "DIED_IN", "LOCATED_IN", "PART_OF", "HAS", "IS_A"), - "MISCELLANEOUS": ( - "USED_BY", - "USED_FOR", - "LOCATED_IN", - "PART_OF", - "WORKED_ON", - "HAS", - "IS_A", - "BORN_IN", - "DIED_IN", - ), -} +# Convert the above dict schema into a list of triples +Triple = Tuple[str, str, str] +DEFAULT_VALIDATION_SCHEMA: List[Triple] = [ + ("PRODUCT", "USED_BY", "PRODUCT"), + ("PRODUCT", "USED_FOR", "MARKET"), + ("PRODUCT", "HAS", "TECHNOLOGY"), + ("MARKET", "LOCATED_IN", "LOCATION"), + ("MARKET", "HAS", "TECHNOLOGY"), + ("TECHNOLOGY", "USED_BY", "PRODUCT"), + ("TECHNOLOGY", "USED_FOR", "MARKET"), + ("TECHNOLOGY", "LOCATED_IN", "LOCATION"), + ("TECHNOLOGY", "PART_OF", "ORGANIZATION"), + ("TECHNOLOGY", "IS_A", "PRODUCT"), + ("EVENT", "LOCATED_IN", "LOCATION"), + ("EVENT", "PART_OF", "ORGANIZATION"), + ("CONCEPT", "USED_BY", "TECHNOLOGY"), + ("CONCEPT", "USED_FOR", "PRODUCT"), + ("ORGANIZATION", "LOCATED_IN", "LOCATION"), + ("ORGANIZATION", "PART_OF", "ORGANIZATION"), + ("ORGANIZATION", "PART_OF", "MARKET"), + ("PERSON", "BORN_IN", "LOCATION"), + ("PERSON", "BORN_IN", "TIME"), + ("PERSON", "DIED_IN", "LOCATION"), + ("PERSON", "DIED_IN", "TIME"), + ("PERSON", "WORKED_ON", "EVENT"), + ("PERSON", "WORKED_ON", "PRODUCT"), + ("PERSON", "WORKED_ON", "CONCEPT"), + ("PERSON", "WORKED_ON", "TECHNOLOGY"), + ("LOCATION", "LOCATED_IN", "LOCATION"), + ("LOCATION", "PART_OF", "LOCATION"), +] DEFAULT_SCHEMA_PATH_EXTRACT_PROMPT = PromptTemplate( "Give the following text, extract the knowledge graph according to the provided schema. " @@ -114,7 +89,8 @@ class SchemaLLMPathExtractor(TransformComponent): - """Extract paths from a graph using a schema. + """ + Extract paths from a graph using a schema. Args: llm (LLM): @@ -154,7 +130,7 @@ def __init__( possible_relations: Optional[TypeAlias] = None, strict: bool = True, kg_schema_cls: Any = None, - kg_validation_schema: Dict[str, str] = None, + kg_validation_schema: Union[Dict[str, str], List[Triple]] = None, max_triplets_per_chunk: int = 10, num_workers: int = 4, ) -> None: @@ -230,11 +206,17 @@ def validate(v: Any, values: Any) -> Any: ) kg_schema_cls.__doc__ = "Knowledge Graph Schema." + # Get validation schema + kg_validation_schema = kg_validation_schema or DEFAULT_VALIDATION_SCHEMA + # TODO: Remove this in a future version & encourage List[Triple] for validation schema + if isinstance(kg_validation_schema, list): + kg_validation_schema = {"relationships": kg_validation_schema} + super().__init__( llm=llm, extract_prompt=extract_prompt or DEFAULT_SCHEMA_PATH_EXTRACT_PROMPT, kg_schema_cls=kg_schema_cls, - kg_validation_schema=kg_validation_schema or DEFAULT_VALIDATION_SCHEMA, + kg_validation_schema=kg_validation_schema, num_workers=num_workers, max_triplets_per_chunk=max_triplets_per_chunk, strict=strict, @@ -264,13 +246,26 @@ def _prune_invalid_triplets(self, kg_schema: Any) -> List[Triplet]: obj = triplet.object.name obj_type = triplet.object.type - # check relations - if relation not in self.kg_validation_schema.get( - subject_type, [relation] - ) and relation not in self.kg_validation_schema.get(obj_type, [relation]): - continue - - # remove self-references + # Check if the triplet is valid based on the schema format + if ( + isinstance(self.kg_validation_schema, dict) + and "relationships" in self.kg_validation_schema + ): + # Schema is a dictionary with a 'relationships' key and triples as values + if (subject_type, relation, obj_type) not in self.kg_validation_schema[ + "relationships" + ]: + continue + else: + # Schema is the backwards-compat format + if relation not in self.kg_validation_schema.get( + subject_type, [relation] + ) and relation not in self.kg_validation_schema.get( + obj_type, [relation] + ): + continue + + # Remove self-references if subject.lower() == obj.lower(): continue