Skip to content

Commit

Permalink
Dev (#687)
Browse files Browse the repository at this point in the history
* Feature/iamteapot (#683)

* i am a teapot

* i am a teapot

* Feature/cleanup kg extraction (#685)

* Update pyproject.toml (#684)

* tweak local kg extr

* update

* update

* cleanups

* fix compose

* fix compose

* fix non-local

* up

* cleanup
  • Loading branch information
emrgnt-cmplxty authored Jul 12, 2024
1 parent 62d92a7 commit c61cf66
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 101 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "r2r"
version = "0.2.58"
version = "0.2.59"
description = "SciPhi R2R"
authors = ["Owen Colegrove <[email protected]>"]
license = "MIT"
Expand Down
3 changes: 2 additions & 1 deletion r2r/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from .pipes.base_pipe import AsyncPipe, AsyncState, PipeType
from .providers.embedding_provider import EmbeddingConfig, EmbeddingProvider
from .providers.eval_provider import EvalConfig, EvalProvider
from .providers.kg_provider import KGConfig, KGProvider
from .providers.kg_provider import KGConfig, KGProvider, update_kg_prompt
from .providers.llm_provider import LLMConfig, LLMProvider
from .providers.prompt_provider import PromptConfig, PromptProvider
from .providers.vector_db_provider import VectorDBConfig, VectorDBProvider
Expand Down Expand Up @@ -143,6 +143,7 @@
"VectorDBProvider",
"KGProvider",
"KGConfig",
"update_kg_prompt",
# Other
"FilterCriteria",
"TextSplitter",
Expand Down
74 changes: 65 additions & 9 deletions r2r/base/providers/kg_provider.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
"""Base classes for knowledge graph providers."""

import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple

from ..abstractions.llama_abstractions import (
EntityNode,
LabelledNode,
Relation,
VectorStoreQuery,
)
from typing import TYPE_CHECKING, Any, Optional, Tuple

from .prompt_provider import PromptProvider

if TYPE_CHECKING:
from r2r.main import R2RClient

from ...base.utils.base_utils import EntityType, Relation
from ..abstractions.llama_abstractions import EntityNode, LabelledNode
from ..abstractions.llama_abstractions import Relation as LlamaRelation
from ..abstractions.llama_abstractions import VectorStoreQuery
from ..abstractions.llm import GenerationConfig
from .base_provider import ProviderConfig

Expand Down Expand Up @@ -76,7 +80,7 @@ def upsert_nodes(self, nodes: list[EntityNode]) -> None:
pass

@abstractmethod
def upsert_relations(self, relations: list[Relation]) -> None:
def upsert_relations(self, relations: list[LlamaRelation]) -> None:
"""Abstract method to add triplet."""
pass

Expand Down Expand Up @@ -124,3 +128,55 @@ def update_kg_agent_prompt(
):
"""Abstract method to update the KG agent prompt."""
pass


def escape_braces(s: str) -> str:
    """Double every curly brace in ``s``.

    ``str.format`` renders ``{{`` and ``}}`` as literal braces, so the
    returned text can be embedded in a format template without its braces
    being mistaken for replacement fields.

    Args:
        s: Text that may contain ``{`` or ``}`` characters.

    Returns:
        A copy of ``s`` with each ``{`` replaced by ``{{`` and each ``}``
        replaced by ``}}``.
    """
    # The second replace only touches "}" characters, none of which are
    # introduced by the first replace, so the two passes do not interfere.
    return s.replace("{", "{{").replace("}", "}}")


# TODO - Make this more configurable / intelligent
def update_kg_prompt(
    client: "R2RClient",
    r2r_prompts: PromptProvider,
    prompt_base: str,
    entity_types: list[EntityType],
    relations: list[Relation],
) -> None:
    """Fill the ``<prompt_base>_with_spec`` template and store it as ``prompt_base``.

    The entity-type and relation names are serialized as indented JSON and
    substituted into the ``_with_spec`` template fetched from
    ``r2r_prompts``. The rendered text is then brace-escaped, leaving only
    ``{input}`` as a live placeholder, and sent to the client via
    ``client.update_prompt``.

    Args:
        client: Object whose ``update_prompt`` method receives the new
            template.
        r2r_prompts: Provider whose ``get_prompt`` renders the
            ``_with_spec`` template.
        prompt_base: Name of the prompt to overwrite; ``"_with_spec"`` is
            appended to find the source template.
        entity_types: Items whose ``name`` values populate the
            ``entity_types`` JSON list.
        relations: Items whose ``name`` values populate the ``predicates``
            JSON list.
    """
    entity_spec: str = json.dumps(
        {"entity_types": [str(item.name) for item in entity_types]},
        indent=4,
    )
    relation_spec: str = json.dumps(
        {"predicates": [str(item.name) for item in relations]},
        indent=4,
    )

    # Passing "{input}" through as the input value keeps the placeholder in
    # the rendered text so it can be restored after escaping below.
    rendered: str = r2r_prompts.get_prompt(
        f"{prompt_base}_with_spec",
        {
            "entity_types": entity_spec,
            "relations": relation_spec,
            "input": "\n{input}",
        },
    )

    # Double every brace (the JSON specs are full of them), then undo the
    # doubling for the single placeholder that must stay formattable.
    doubled: str = escape_braces(rendered)
    final_template: str = doubled.replace("{{input}}", "{input}")

    client.update_prompt(
        prompt_base,
        template=final_template,
        input_types={"input": "str"},
    )
6 changes: 3 additions & 3 deletions r2r/base/utils/base_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional
from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable

if TYPE_CHECKING:
from ..pipeline.base_pipeline import AsyncPipeline
Expand Down Expand Up @@ -49,14 +49,14 @@ def __init__(self, name: str):
self.name = name


def format_entity_types(entity_types: List[EntityType]) -> str:
def format_entity_types(entity_types: list[EntityType]) -> str:
lines = []
for entity in entity_types:
lines.append(entity.name)
return "\n".join(lines)


def format_relations(predicates: List[Relation]) -> str:
def format_relations(predicates: list[Relation]) -> str:
lines = []
for predicate in predicates:
lines.append(predicate.name)
Expand Down
9 changes: 6 additions & 3 deletions r2r/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import uuid

import click
import requests
from dotenv import load_dotenv

from r2r.main.execution import R2RExecutionWrapper
Expand Down Expand Up @@ -133,8 +132,9 @@ def serve(obj, host, port, docker, docker_ext_neo4j, project_name):
is_flag=True,
help="Remove containers for services not defined in the Compose file",
)
@click.option("--project-name", default="r2r", help="Project name for Docker")
@click.pass_context
def docker_down(ctx, volumes, remove_orphans):
def docker_down(ctx, volumes, remove_orphans, project_name):
"""Bring down the Docker Compose setup and attempt to remove the network if necessary."""
package_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", ".."
Expand All @@ -151,15 +151,18 @@ def docker_down(ctx, volumes, remove_orphans):
return

docker_command = (
f"docker-compose -f {compose_yaml} -f {compose_neo4j_yaml} down"
f"docker-compose -f {compose_yaml} -f {compose_neo4j_yaml}"
)
docker_command += f" --project-name {project_name}"

if volumes:
docker_command += " --volumes"

if remove_orphans:
docker_command += " --remove-orphans"

docker_command += " down"

click.echo("Bringing down Docker Compose setup...")
result = os.system(docker_command)

Expand Down
6 changes: 4 additions & 2 deletions r2r/examples/configs/local_neo4j_kg.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
"batch_size": 1,
"text_splitter": {
"type": "recursive_character",
"chunk_size": 1024,
"chunk_size": 512,
"chunk_overlap": 0
},
"max_entities": 10,
"max_relations": 20,
"kg_extraction_prompt": "zero_shot_ner_kg_extraction",
"kg_extraction_config": {
"model": "ollama/sciphi/triplex",
"temperature": 0.7,
"temperature": 1.0,
"top_p": 1.0,
"top_k": 100,
"max_tokens_to_sample": 1024,
Expand Down
90 changes: 10 additions & 80 deletions r2r/examples/scripts/advanced_kg_cookbook.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
import json
import os
import string

import fire
import requests
from bs4 import BeautifulSoup, Comment

from r2r import (
Document,
EntityType,
KGSearchSettings,
R2RBuilder,
R2RClient,
R2RPromptProvider,
Relation,
VectorSearchSettings,
generate_id_from_label,
update_kg_prompt,
)
from r2r.base.abstractions.llm import GenerationConfig


def escape_braces(text):
Expand Down Expand Up @@ -100,36 +95,16 @@ def main(

# Specify the entity types for the KG extraction prompt
entity_types = [
EntityType("ORGANIZATION"),
EntityType("COMPANY"),
EntityType("SCHOOL"),
EntityType("NON-PROFIT"),
EntityType("LOCATION"),
EntityType("CITY"),
EntityType("STATE"),
EntityType("COUNTRY"),
EntityType("PERSON"),
EntityType("POSITION"),
EntityType("DATE"),
EntityType("YEAR"),
EntityType("MONTH"),
EntityType("DAY"),
EntityType("BATCH"),
EntityType("OTHER"),
EntityType("QUANTITY"),
EntityType("EVENT"),
EntityType("INCORPORATION"),
EntityType("FUNDING_ROUND"),
EntityType("ACQUISITION"),
EntityType("LAUNCH"),
EntityType("INDUSTRY"),
EntityType("MEDIA"),
EntityType("EMAIL"),
EntityType("WEBSITE"),
EntityType("TWITTER"),
EntityType("LINKEDIN"),
EntityType("OTHER"),
EntityType("PRODUCT"),
]

# Specify the relations for the KG construction
Expand All @@ -149,54 +124,24 @@ def main(
# Product relations
Relation("PRODUCT"),
Relation("FEATURES"),
Relation("USES"),
Relation("USED_BY"),
Relation("TECHNOLOGY"),
# Additional relations
Relation("HAS"),
Relation("AS_OF"),
Relation("PARTICIPATED"),
Relation("ASSOCIATED"),
Relation("GROUP_PARTNER"),
Relation("ALIAS"),
]

client = R2RClient(base_url=base_url)
r2r_prompts = R2RPromptProvider()

# get the default extraction template
# note that 'local' templates omit the n-shot example
new_template = r2r_prompts.get_prompt(
(
"zero_shot_ner_kg_extraction_with_spec"
if local_mode
else "few_shot_ner_kg_extraction_with_spec"
),
{
"entity_types": "\n".join(
[str(entity.name) for entity in entity_types]
),
"relations": "\n".join(
[str(relation.name) for relation in relations]
),
"input": """\n{input}""",
},
)

# Escape all braces in the template, except for the {input} placeholder, for formatting
escaped_template = escape_braces(new_template).replace(
"""{{input}}""", """{input}"""
prompt_base = (
"zero_shot_ner_kg_extraction"
if local_mode
else "few_shot_ner_kg_extraction"
)

client.update_prompt(
(
"zero_shot_ner_kg_extraction"
if local_mode
else "few_shot_ner_kg_extraction"
),
template=escaped_template,
input_types={"input": "str"},
)
update_kg_prompt(client, r2r_prompts, prompt_base, entity_types, relations)

url_map = get_all_yc_co_directory_urls()

Expand Down Expand Up @@ -224,25 +169,10 @@ def main(

print(client.inspect_knowledge_graph(1_000)["results"])

new_template = r2r_prompts.get_prompt(
"kg_agent_with_spec",
{
"entity_types": "\n".join(
[str(entity.name) for entity in entity_types]
),
"relations": "\n".join(
[str(relation.name) for relation in relations]
),
"input": """\n{input}""",
},
)
if not local_mode:
# RAG client currently only works with powerful remote LLMs,
# we are working to expand support to local LLMs.
client.update_prompt(
"kg_agent",
template=new_template,
input_types={"input": "str"},

update_kg_prompt(
client, r2r_prompts, "kg_agent", entity_types, relations
)

result = client.search(
Expand Down
1 change: 0 additions & 1 deletion r2r/main/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
R2RDeleteRequest,
R2RDocumentChunksRequest,
R2RDocumentsOverviewRequest,
R2RExtractionRequest,
R2RIngestFilesRequest,
R2RLogsRequest,
R2RPrintRelationshipsRequest,
Expand Down
Loading

0 comments on commit c61cf66

Please sign in to comment.