Skip to content

Commit

Permalink
Update schema_llm (#14357)
Browse files Browse the repository at this point in the history
  • Loading branch information
prrao87 authored Jun 26, 2024
1 parent 1bbb2f6 commit 01da082
Showing 1 changed file with 62 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

try:
from typing import TypeAlias
Expand Down Expand Up @@ -47,62 +47,37 @@
"HAS_ALIAS",
]

# Which entities can be connected to which relations.
# Maps an entity type to the relation labels it may take part in; the
# dict-format check looks a relation up under both the subject type and the
# object type of a triplet.
DEFAULT_VALIDATION_SCHEMA: Dict[str, Any] = {
    "PRODUCT": (
        "USED_BY",
        "USED_FOR",
        "LOCATED_IN",
        "PART_OF",
        "WORKED_ON",
        "HAS",
        "IS_A",
    ),
    "MARKET": ("LOCATED_IN", "PART_OF", "WORKED_ON", "HAS", "IS_A"),
    "TECHNOLOGY": (
        "USED_BY",
        "USED_FOR",
        "LOCATED_IN",
        "PART_OF",
        "WORKED_ON",
        "HAS",
        "IS_A",
    ),
    "EVENT": ("LOCATED_IN", "PART_OF", "WORKED_ON", "HAS", "IS_A"),
    # Fixed: this entry previously read "USE_BY", which matches no relation
    # label used anywhere else in the schema.
    "CONCEPT": ("USED_BY", "USED_FOR", "PART_OF", "WORKED_ON", "HAS", "IS_A"),
    "ORGANIZATION": ("LOCATED_IN", "PART_OF", "HAS", "IS_A"),
    "PERSON": (
        "BORN_IN",
        "DIED_IN",
        "LOCATED_IN",
        "PART_OF",
        "WORKED_ON",
        "HAS",
        "IS_A",
    ),
    "LOCATION": (
        "LOCATED_IN",
        "PART_OF",
        "HAS",
        "IS_A",
        "DIED_IN",
        "BORN_IN",
        "USED_BY",
        "USED_FOR",
    ),
    "TIME": ("BORN_IN", "DIED_IN", "LOCATED_IN", "PART_OF", "HAS", "IS_A"),
    "MISCELLANEOUS": (
        "USED_BY",
        "USED_FOR",
        "LOCATED_IN",
        "PART_OF",
        "WORKED_ON",
        "HAS",
        "IS_A",
        "BORN_IN",
        "DIED_IN",
    ),
}
# Default validation schema expressed as (subject type, relation, object type)
# triples. The edges are grouped by subject type below and flattened in
# declaration order, so the resulting list order is stable.
Triple = Tuple[str, str, str]
DEFAULT_VALIDATION_SCHEMA: List[Triple] = [
    (subject_type, relation, object_type)
    for subject_type, allowed_edges in (
        (
            "PRODUCT",
            (
                ("USED_BY", "PRODUCT"),
                ("USED_FOR", "MARKET"),
                ("HAS", "TECHNOLOGY"),
            ),
        ),
        (
            "MARKET",
            (
                ("LOCATED_IN", "LOCATION"),
                ("HAS", "TECHNOLOGY"),
            ),
        ),
        (
            "TECHNOLOGY",
            (
                ("USED_BY", "PRODUCT"),
                ("USED_FOR", "MARKET"),
                ("LOCATED_IN", "LOCATION"),
                ("PART_OF", "ORGANIZATION"),
                ("IS_A", "PRODUCT"),
            ),
        ),
        (
            "EVENT",
            (
                ("LOCATED_IN", "LOCATION"),
                ("PART_OF", "ORGANIZATION"),
            ),
        ),
        (
            "CONCEPT",
            (
                ("USED_BY", "TECHNOLOGY"),
                ("USED_FOR", "PRODUCT"),
            ),
        ),
        (
            "ORGANIZATION",
            (
                ("LOCATED_IN", "LOCATION"),
                ("PART_OF", "ORGANIZATION"),
                ("PART_OF", "MARKET"),
            ),
        ),
        (
            "PERSON",
            (
                ("BORN_IN", "LOCATION"),
                ("BORN_IN", "TIME"),
                ("DIED_IN", "LOCATION"),
                ("DIED_IN", "TIME"),
                ("WORKED_ON", "EVENT"),
                ("WORKED_ON", "PRODUCT"),
                ("WORKED_ON", "CONCEPT"),
                ("WORKED_ON", "TECHNOLOGY"),
            ),
        ),
        (
            "LOCATION",
            (
                ("LOCATED_IN", "LOCATION"),
                ("PART_OF", "LOCATION"),
            ),
        ),
    )
    for relation, object_type in allowed_edges
]

DEFAULT_SCHEMA_PATH_EXTRACT_PROMPT = PromptTemplate(
"Give the following text, extract the knowledge graph according to the provided schema. "
Expand All @@ -114,7 +89,8 @@


class SchemaLLMPathExtractor(TransformComponent):
"""Extract paths from a graph using a schema.
"""
Extract paths from a graph using a schema.
Args:
llm (LLM):
Expand Down Expand Up @@ -154,7 +130,7 @@ def __init__(
possible_relations: Optional[TypeAlias] = None,
strict: bool = True,
kg_schema_cls: Any = None,
kg_validation_schema: Dict[str, str] = None,
kg_validation_schema: Union[Dict[str, str], List[Triple]] = None,
max_triplets_per_chunk: int = 10,
num_workers: int = 4,
) -> None:
Expand Down Expand Up @@ -230,11 +206,17 @@ def validate(v: Any, values: Any) -> Any:
)
kg_schema_cls.__doc__ = "Knowledge Graph Schema."

# Get validation schema
kg_validation_schema = kg_validation_schema or DEFAULT_VALIDATION_SCHEMA
# TODO: Remove this in a future version & encourage List[Triple] for validation schema
if isinstance(kg_validation_schema, list):
kg_validation_schema = {"relationships": kg_validation_schema}

super().__init__(
llm=llm,
extract_prompt=extract_prompt or DEFAULT_SCHEMA_PATH_EXTRACT_PROMPT,
kg_schema_cls=kg_schema_cls,
kg_validation_schema=kg_validation_schema or DEFAULT_VALIDATION_SCHEMA,
kg_validation_schema=kg_validation_schema,
num_workers=num_workers,
max_triplets_per_chunk=max_triplets_per_chunk,
strict=strict,
Expand Down Expand Up @@ -264,13 +246,26 @@ def _prune_invalid_triplets(self, kg_schema: Any) -> List[Triplet]:
obj = triplet.object.name
obj_type = triplet.object.type

# check relations
if relation not in self.kg_validation_schema.get(
subject_type, [relation]
) and relation not in self.kg_validation_schema.get(obj_type, [relation]):
continue

# remove self-references
# Check if the triplet is valid based on the schema format
if (
isinstance(self.kg_validation_schema, dict)
and "relationships" in self.kg_validation_schema
):
# Schema is a dictionary with a 'relationships' key and triples as values
if (subject_type, relation, obj_type) not in self.kg_validation_schema[
"relationships"
]:
continue
else:
# Schema is the backwards-compat format
if relation not in self.kg_validation_schema.get(
subject_type, [relation]
) and relation not in self.kg_validation_schema.get(
obj_type, [relation]
):
continue

# Remove self-references
if subject.lower() == obj.lower():
continue

Expand Down

0 comments on commit 01da082

Please sign in to comment.