Skip to content

Commit

Permalink
Improve JSON extraction, add extraction to Task Memory
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Aug 5, 2024
1 parent fe53c41 commit af3d0f2
Show file tree
Hide file tree
Showing 22 changed files with 235 additions and 63 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `AstraDbVectorStoreDriver` to support DataStax Astra DB as a vector store.
- Ability to set custom schema properties on Tool Activities via `extra_schema_properties`.
- `extract_json` and `extract_csv` Activities to `TaskMemoryClient`.
- `extract_json_namespace` and `extract_csv_namespace` methods to `TaskMemory`.

### Changed
- **BREAKING**: Split parameter `JsonExtractionEngine.template_generator` into `system_template_generator` and `user_template_generator`.
- **BREAKING**: Split parameter `CsvExtractionEngine.template_generator` into `system_template_generator` and `user_template_generator`.
- **BREAKING**: Split `JsonExtractionEngine.extract` into `extract_text` and `extract_artifacts`.
- **BREAKING**: Split `CsvExtractionEngine.extract` into `extract_text` and `extract_artifacts`.
- Parse JSON from LLM output before loading in `JsonExtractionEngine`.

### Fixed
- Missing implementations of `csv_extraction_engine` and `json_extraction_engine` in `TextArtifactStorage`.

## [0.29.0] - 2024-07-30

Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/engines/extraction-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Charlie is 40 and lives in Texas.
"""

# Extract CSV rows using the engine
result = csv_engine.extract(sample_text, column_names=["name", "age", "location"])
result = csv_engine.extract_text(sample_text, column_names=["name", "age", "location"])

for row in result.value:
print(row.to_text())
Expand Down Expand Up @@ -73,7 +73,7 @@ user_schema = Schema(
).json_schema("UserSchema")

# Extract data using the engine
result = json_engine.extract(sample_json_text, template_schema=user_schema)
result = json_engine.extract_text(sample_json_text, template_schema=user_schema)

for artifact in result.value:
print(artifact.value)
Expand Down
12 changes: 9 additions & 3 deletions griptape/engines/extraction/base_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from attrs import Attribute, Factory, define, field

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.chunkers import BaseChunker, TextChunker

if TYPE_CHECKING:
from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.artifacts import ErrorArtifact
from griptape.drivers import BasePromptDriver
from griptape.rules import Ruleset

Expand Down Expand Up @@ -45,10 +46,15 @@ def min_response_tokens(self) -> int:
)

@abstractmethod
def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact,
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact | ErrorArtifact: ...

def extract_text(
    self, text: str, *, rulesets: Optional[list[Ruleset]] = None, **kwargs
) -> ListArtifact | ErrorArtifact:
    """Run extraction on a plain string.

    The string is wrapped in a single ``TextArtifact`` inside a
    ``ListArtifact`` and handed to ``extract_artifacts``.

    Args:
        text: Raw text to extract from.
        rulesets: Optional rulesets, forwarded unchanged.
        **kwargs: Extra keyword arguments, forwarded unchanged to
            ``extract_artifacts``.

    Returns:
        Whatever ``extract_artifacts`` returns for the wrapped text.
    """
    wrapped = ListArtifact([TextArtifact(text)])
    return self.extract_artifacts(wrapped, rulesets=rulesets, **kwargs)
42 changes: 30 additions & 12 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

@define
class CsvExtractionEngine(BaseExtractionEngine):
template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv_extraction.j2")), kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/csv_extraction/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/csv_extraction/user.j2")), kw_only=True)

def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact,
*,
rulesets: Optional[list[Ruleset]] = None,
column_names: Optional[list[str]] = None,
Expand All @@ -33,7 +34,7 @@ def extract(
try:
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
cast(list[TextArtifact], artifacts.value),
column_names,
[],
rulesets=rulesets,
Expand All @@ -60,32 +61,49 @@ def _extract_rec(
rulesets: Optional[list[Ruleset]] = None,
) -> list[CsvRowArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
full_text = self.template_generator.render(
system_prompt = self.system_template_generator.render(
column_names=column_names,
text=artifacts_text,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
text=artifacts_text,
)

if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens:
if (
self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt)
>= self.min_response_tokens
):
rows.extend(
self.text_to_csv_rows(
self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value,
self.prompt_driver.run(
PromptStack(
messages=[
Message(system_prompt, role=Message.SYSTEM_ROLE),
Message(user_prompt, role=Message.USER_ROLE),
]
)
).value,
column_names,
),
)

return rows
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.template_generator.render(
column_names=column_names,
partial_text = self.user_template_generator.render(

Check warning on line 93 in griptape/engines/extraction/csv_extraction_engine.py

View check run for this annotation

Codecov / codecov/patch

griptape/engines/extraction/csv_extraction_engine.py#L93

Added line #L93 was not covered by tests
text=chunks[0].value,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)

rows.extend(
self.text_to_csv_rows(
self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value,
self.prompt_driver.run(
PromptStack(
messages=[
Message(system_prompt, role=Message.SYSTEM_ROLE),
Message(partial_text, role=Message.USER_ROLE),
]
)
).value,
column_names,
),
)
Expand Down
54 changes: 41 additions & 13 deletions griptape/engines/extraction/json_extraction_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING, Optional, cast

from attrs import Factory, define, field
Expand All @@ -17,11 +18,16 @@

@define
class JsonExtractionEngine(BaseExtractionEngine):
template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json_extraction.j2")), kw_only=True)
JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

def extract(
system_template_generator: J2 = field(
default=Factory(lambda: J2("engines/json_extraction/system.j2")), kw_only=True
)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/json_extraction/user.j2")), kw_only=True)

def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact,
*,
rulesets: Optional[list[Ruleset]] = None,
template_schema: Optional[list[dict]] = None,
Expand All @@ -34,7 +40,7 @@ def extract(

return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
cast(list[TextArtifact], artifacts.value),
json_schema,
[],
rulesets=rulesets,
Expand All @@ -45,7 +51,12 @@ def extract(
return ErrorArtifact(f"error extracting JSON: {e}")

def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]:
    """Turn raw LLM output into one ``TextArtifact`` per JSON array element.

    The output may contain text around the JSON array, so the array is first
    located with ``JSON_PATTERN`` and only the matched span is parsed. When
    several matches are found, the last one is used.

    Args:
        json_input: Text returned by the prompt driver.

    Returns:
        A list with one ``TextArtifact`` per element of the parsed array, each
        holding that element re-serialized with ``json.dumps``. Empty when no
        array-like span is found.

    Raises:
        json.JSONDecodeError: If the matched span is not valid JSON.
    """
    json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

    # No array in the output: report "nothing extracted" instead of trying to
    # parse arbitrary text as JSON.
    if not json_matches:
        return []

    return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])]

def _extract_rec(
self,
Expand All @@ -55,31 +66,48 @@ def _extract_rec(
rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
full_text = self.template_generator.render(
system_prompt = self.system_template_generator.render(
json_template_schema=json_template_schema,
text=artifacts_text,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
text=artifacts_text,
)

if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens:
if (
self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
>= self.min_response_tokens
):
extractions.extend(
self.json_to_text_artifacts(
self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value,
self.prompt_driver.run(
PromptStack(
messages=[
Message(system_prompt, role=Message.SYSTEM_ROLE),
Message(user_prompt, role=Message.USER_ROLE),
]
)
).value
),
)

return extractions
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.template_generator.render(
template_schema=json_template_schema,
partial_text = self.user_template_generator.render(

Check warning on line 97 in griptape/engines/extraction/json_extraction_engine.py

View check run for this annotation

Codecov / codecov/patch

griptape/engines/extraction/json_extraction_engine.py#L97

Added line #L97 was not covered by tests
text=chunks[0].value,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)

extractions.extend(
self.json_to_text_artifacts(
self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value,
self.prompt_driver.run(
PromptStack(
messages=[
Message(system_prompt, role=Message.SYSTEM_ROLE),
Message(partial_text, role=Message.USER_ROLE),
]
)
).value,
),
)

Expand Down
8 changes: 7 additions & 1 deletion griptape/memory/task/storage/base_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from attrs import define

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact
from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact


@define
Expand All @@ -25,3 +25,9 @@ def summarize(self, namespace: str) -> TextArtifact | InfoArtifact: ...

@abstractmethod
def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: ...

@abstractmethod
def extract_csv(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact: ...

@abstractmethod
def extract_json(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact: ...
6 changes: 6 additions & 0 deletions griptape/memory/task/storage/blob_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ def summarize(self, namespace: str) -> InfoArtifact:

def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact:
return InfoArtifact("can't query artifacts")

def extract_csv(self, namespace: str) -> InfoArtifact:
    """Report that CSV extraction is not available for this storage.

    Args:
        namespace: Accepted for interface compatibility; not used.

    Returns:
        An ``InfoArtifact`` carrying a fixed explanatory message.
    """
    message = "can't extract csv"
    return InfoArtifact(message)

Check warning on line 37 in griptape/memory/task/storage/blob_artifact_storage.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/storage/blob_artifact_storage.py#L37

Added line #L37 was not covered by tests

def extract_json(self, namespace: str) -> InfoArtifact:
    """Report that JSON extraction is not available for this storage.

    Args:
        namespace: Accepted for interface compatibility; not used.

    Returns:
        An ``InfoArtifact`` carrying a fixed explanatory message.
    """
    message = "can't extract json"
    return InfoArtifact(message)

Check warning on line 40 in griptape/memory/task/storage/blob_artifact_storage.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/storage/blob_artifact_storage.py#L40

Added line #L40 was not covered by tests
13 changes: 13 additions & 0 deletions griptape/memory/task/storage/text_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from griptape.memory.task.storage import BaseArtifactStorage

if TYPE_CHECKING:
from griptape.artifacts import ErrorArtifact
from griptape.drivers import BaseVectorStoreDriver
from griptape.engines import BaseSummaryEngine, CsvExtractionEngine, JsonExtractionEngine

Expand Down Expand Up @@ -45,6 +46,18 @@ def summarize(self, namespace: str) -> TextArtifact:

return self.summary_engine.summarize_artifacts(self.load_artifacts(namespace))

def extract_csv(self, namespace: str) -> ListArtifact | ErrorArtifact:
    """Extract CSV rows from the artifacts stored under ``namespace``.

    Args:
        namespace: Passed to ``self.load_artifacts`` to fetch the input.

    Returns:
        The result of the CSV extraction engine's ``extract_artifacts``.

    Raises:
        ValueError: If ``csv_extraction_engine`` is ``None``.
    """
    engine = self.csv_extraction_engine
    if engine is None:
        raise ValueError("Csv extraction engine is not set.")

    return engine.extract_artifacts(self.load_artifacts(namespace))

def extract_json(self, namespace: str) -> ListArtifact | ErrorArtifact:
    """Extract JSON objects from the artifacts stored under ``namespace``.

    Args:
        namespace: Passed to ``self.load_artifacts`` to fetch the input.

    Returns:
        The result of the JSON extraction engine's ``extract_artifacts``.

    Raises:
        ValueError: If ``json_extraction_engine`` is ``None``.
    """
    engine = self.json_extraction_engine
    if engine is None:
        raise ValueError("Json extraction engine is not set.")

    return engine.extract_artifacts(self.load_artifacts(namespace))

def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact:
if self.rag_engine is None:
raise ValueError("rag_engine is not set")
Expand Down
16 changes: 16 additions & 0 deletions griptape/memory/task/task_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,19 @@ def query_namespace(self, namespace: str, query: str) -> BaseArtifact:
return storage.query(namespace=namespace, query=query, metadata=self.namespace_metadata.get(namespace))
else:
return InfoArtifact("Can't find memory content")

def extract_json_namespace(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact:
    """Run JSON extraction on the storage registered for ``namespace``.

    Args:
        namespace: Key looked up in ``self.namespace_storage``.

    Returns:
        The storage's ``extract_json`` result, or an ``ErrorArtifact`` when no
        (truthy) storage is registered for the namespace.
    """
    storage = self.namespace_storage.get(namespace)

    # Guard clause: nothing registered for this namespace.
    if not storage:
        return ErrorArtifact("Can't find memory content")

    return storage.extract_json(namespace)

Check warning on line 149 in griptape/memory/task/task_memory.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/task_memory.py#L149

Added line #L149 was not covered by tests

def extract_csv_namespace(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact:
    """Run CSV extraction on the storage registered for ``namespace``.

    Args:
        namespace: Key looked up in ``self.namespace_storage``.

    Returns:
        The storage's ``extract_csv`` result, or an ``ErrorArtifact`` when no
        (truthy) storage is registered for the namespace.
    """
    storage = self.namespace_storage.get(namespace)

    # Guard clause: nothing registered for this namespace.
    if not storage:
        return ErrorArtifact("Can't find memory content")

    return storage.extract_csv(namespace)

Check warning on line 157 in griptape/memory/task/task_memory.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/task_memory.py#L157

Added line #L157 was not covered by tests
2 changes: 1 addition & 1 deletion griptape/tasks/extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ def extraction_engine(self) -> BaseExtractionEngine:
return self._extraction_engine

def run(self) -> ListArtifact | ErrorArtifact:
return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args)
return self.extraction_engine.extract_text(self.input.to_text(), rulesets=self.all_rulesets, **self.args)
7 changes: 7 additions & 0 deletions griptape/templates/engines/csv_extraction/system.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Don't add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes.
Column Names: """{{ column_names }}"""

{% if rulesets %}

{{ rulesets }}
{% endif %}
4 changes: 4 additions & 0 deletions griptape/templates/engines/csv_extraction/user.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Extract information from the Text based on the Column Names and output it as a CSV file.
Text: """{{ text }}"""

Answer:
11 changes: 0 additions & 11 deletions griptape/templates/engines/extraction/csv_extraction.j2

This file was deleted.

6 changes: 6 additions & 0 deletions griptape/templates/engines/json_extraction/system.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Extraction Template JSON Schema: """{{ json_template_schema }}"""

{% if rulesets %}

{{ rulesets }}
{% endif %}
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
Text: """{{ text }}"""

Extraction Template JSON Schema: """{{ json_template_schema }}"""

Extract information from the Text based on the Extraction Template JSON Schema into an array of JSON objects.
{% if rulesets %}

{{ rulesets }}
{% endif %}
Text: """{{ text }}"""

JSON array:
Loading

0 comments on commit af3d0f2

Please sign in to comment.