Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge main to dev #786

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## [0.25.1] - 2024-05-15

### Fixed
- Honor `namespace` in `RedisVectorStoreDriver.query()`.
- Correctly set the `meta`, `score`, and `vector` fields of query result returned from `RedisVectorStoreDriver.query()`.
- Standardize behavior between omitted and empty actions list when initializing `ActionsSubtask`.

### Added
- Optional event batching on Event Listener Drivers.
- `id` field to all events.

### Changed
- Default behavior of Event Listener Drivers to batch events.
- Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver.

## [0.25.0] - 2024-05-06

### Added
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/vector-store-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ print(result)

The format for creating a vector index should be similar to the following:
```
FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA tag TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE
FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA namespace TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE
```

## OpenSearch Vector Store Driver
Expand Down
2 changes: 1 addition & 1 deletion griptape/config/openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class OpenAiStructureConfig(BaseStructureConfig):
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(
lambda: StructureGlobalDriversConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"),
image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"),
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ class AmazonSqsEventListenerDriver(BaseEventListenerDriver):

def try_publish_event_payload(self, event_payload: dict) -> None:
self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
entries = [
{"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
for event_payload in event_payload_batch
]

self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ class AwsIotCoreEventListenerDriver(BaseEventListenerDriver):

def try_publish_event_payload(self, event_payload: dict) -> None:
self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload))

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload_batch))
32 changes: 25 additions & 7 deletions griptape/drivers/event_listener/base_event_listener_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,37 @@
@define
class BaseEventListenerDriver(ABC):
futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)
batched: bool = field(default=True, kw_only=True)
batch_size: int = field(default=10, kw_only=True)

def publish_event(self, event: BaseEvent | dict) -> None:
if isinstance(event, dict):
self.futures_executor.submit(self._safe_try_publish_event_payload, event)
else:
self.futures_executor.submit(self._safe_try_publish_event_payload, event.to_dict())
_batch: list[dict] = field(default=Factory(list), kw_only=True)

@property
def batch(self) -> list[dict]:
return self._batch

def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None:
self.futures_executor.submit(self._safe_try_publish_event, event, flush)

@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None:
...

def _safe_try_publish_event_payload(self, event_payload: dict) -> None:
@abstractmethod
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
...

def _safe_try_publish_event(self, event: BaseEvent | dict, flush: bool) -> None:
try:
self.try_publish_event_payload(event_payload)
event_payload = event if isinstance(event, dict) else event.to_dict()

if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size or flush:
self.try_publish_event_payload_batch(self.batch)
self._batch = []
return
else:
self.try_publish_event_payload(event_payload)
except Exception as e:
logger.error(e)
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,9 @@ def try_publish_event_payload(self, event_payload: dict) -> None:

response = requests.post(url=url, json=event_payload, headers=self.headers)
response.raise_for_status()

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events")

response = requests.post(url=url, json=event_payload_batch, headers=self.headers)
response.raise_for_status()
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ class WebhookEventListenerDriver(BaseEventListenerDriver):
def try_publish_event_payload(self, event_payload: dict) -> None:
response = requests.post(url=self.webhook_url, json=event_payload, headers=self.headers)
response.raise_for_status()

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
response = requests.post(url=self.webhook_url, json=event_payload_batch, headers=self.headers)
response.raise_for_status()
12 changes: 8 additions & 4 deletions griptape/drivers/vector/redis_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def upsert_vector(
mapping["vector"] = np.array(vector, dtype=np.float32).tobytes()
mapping["vec_string"] = bytes_vector

if namespace:
mapping["namespace"] = namespace

if meta:
mapping["metadata"] = json.dumps(meta)

Expand Down Expand Up @@ -120,8 +123,9 @@ def query(

vector = self.embedding_driver.embed_string(query)

filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*"
query_expression = (
Query(f"*=>[KNN {count or 10} @vector $vector as score]")
Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]")
.sort_by("score")
.return_fields("id", "score", "metadata", "vec_string")
.paging(0, count or 10)
Expand All @@ -134,15 +138,15 @@ def query(

query_results = []
for document in results:
metadata = getattr(document, "metadata", None)
metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None
namespace = document.id.split(":")[0] if ":" in document.id else None
vector_id = document.id.split(":")[1] if ":" in document.id else document.id
vector_float_list = json.loads(document["vec_string"]) if include_vectors else None
vector_float_list = json.loads(document.vec_string) if include_vectors else None
query_results.append(
BaseVectorStoreDriver.QueryResult(
id=vector_id,
vector=vector_float_list,
score=float(document["score"]),
score=float(document.score),
meta=metadata,
namespace=namespace,
)
Expand Down
6 changes: 5 additions & 1 deletion griptape/events/base_event.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

import time
import uuid
from abc import ABC
from attr import define, field, Factory

from attr import Factory, define, field

from griptape.mixins import SerializableMixin


@define
class BaseEvent(SerializableMixin, ABC):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
timestamp: float = field(default=Factory(lambda: time.time()), kw_only=True, metadata={"serializable": True})
6 changes: 3 additions & 3 deletions griptape/events/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ class EventListener:
event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True)
driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True)

def publish_event(self, event: BaseEvent) -> None:
def publish_event(self, event: BaseEvent, flush: bool = False) -> None:
event_types = self.event_types

if event_types is None or type(event) in event_types:
event_payload = self.handler(event)
if self.driver is not None:
if event_payload is not None and isinstance(event_payload, dict):
self.driver.publish_event(event_payload)
self.driver.publish_event(event_payload, flush=flush)
else:
self.driver.publish_event(event)
self.driver.publish_event(event, flush=flush)
7 changes: 4 additions & 3 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ def remove_event_listener(self, event_listener: EventListener) -> None:
else:
raise ValueError("Event Listener not found.")

def publish_event(self, event: BaseEvent) -> None:
def publish_event(self, event: BaseEvent, flush: bool = False) -> None:
for event_listener in self.event_listeners:
event_listener.publish_event(event)
event_listener.publish_event(event, flush)

def context(self, task: BaseTask) -> dict[str, Any]:
return {"args": self.execution_args, "structure": self}
Expand All @@ -269,7 +269,8 @@ def after_run(self) -> None:
structure_id=self.id,
output_task_input=self.output_task.input,
output_task_output=self.output_task.output,
)
),
flush=True,
)

@abstractmethod
Expand Down
114 changes: 59 additions & 55 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,69 +198,73 @@ def __init_from_prompt(self, value: str) -> None:
if self.thought is None and len(thought_matches) > 0:
self.thought = thought_matches[-1]

if len(actions_matches) > 0:
try:
data = actions_matches[-1]
actions_list: list = json.loads(data, strict=False)
self.__parse_actions(actions_matches)

# If there are no actions to take but an answer is provided, set the answer as the output.
if len(self.actions) == 0 and self.output is None and len(answer_matches) > 0:
self.output = TextArtifact(answer_matches[-1])

def __parse_actions(self, actions_matches: list[str]) -> None:
if len(actions_matches) == 0:
return

try:
data = actions_matches[-1]
actions_list: list = json.loads(data, strict=False)

if isinstance(self.origin_task, ActionsSubtaskOriginMixin):
self.origin_task.actions_schema().validate(actions_list)

for action_object in actions_list:
# Load action name; throw exception if the key is not present
action_tag = action_object["tag"]

# Load action name; throw exception if the key is not present
action_name = action_object["name"]

# Load action method; throw exception if the key is not present
action_path = action_object["path"]

# Load optional input value; don't throw exceptions if key is not present
if "input" in action_object:
# The schema library has a bug, where something like `Or(str, None)` doesn't get
# correctly translated into JSON schema. For some optional input fields LLMs sometimes
# still provide null value, which trips up the validator. The temporary solution that
# works is to strip all key-values where value is null.
action_input = remove_null_values_in_dict_recursively(action_object["input"])
else:
action_input = {}

# Load the action itself
if isinstance(self.origin_task, ActionsSubtaskOriginMixin):
self.origin_task.actions_schema().validate(actions_list)

if not actions_list:
raise schema.SchemaError("Array item count 0 is less than minimum count of 1.")

for action_object in actions_list:
# Load action name; throw exception if the key is not present
action_tag = action_object["tag"]

# Load action name; throw exception if the key is not present
action_name = action_object["name"]

# Load action method; throw exception if the key is not present
action_path = action_object["path"]

# Load optional input value; don't throw exceptions if key is not present
if "input" in action_object:
# The schema library has a bug, where something like `Or(str, None)` doesn't get
# correctly translated into JSON schema. For some optional input fields LLMs sometimes
# still provide null value, which trips up the validator. The temporary solution that
# works is to strip all key-values where value is null.
action_input = remove_null_values_in_dict_recursively(action_object["input"])
else:
action_input = {}

# Load the action itself
if isinstance(self.origin_task, ActionsSubtaskOriginMixin):
tool = self.origin_task.find_tool(action_name)
else:
raise Exception(
"ActionSubtask must be attached to a Task that implements ActionSubtaskOriginMixin."
)

new_action = ActionsSubtask.Action(
tag=action_tag, name=action_name, path=action_path, input=action_input, tool=tool
tool = self.origin_task.find_tool(action_name)
else:
raise Exception(
"ActionSubtask must be attached to a Task that implements ActionSubtaskOriginMixin."
)

if new_action.tool:
if new_action.input:
self.__validate_action(new_action)
new_action = ActionsSubtask.Action(
tag=action_tag, name=action_name, path=action_path, input=action_input, tool=tool
)

# Don't forget to add it to the subtask actions list!
self.actions.append(new_action)
except SyntaxError as e:
self.structure.logger.error(f"Subtask {self.origin_task.id}\nSyntax error: {e}")
if new_action.tool:
if new_action.input:
self.__validate_action(new_action)

self.actions.append(self.__error_to_action(f"syntax error: {e}"))
except schema.SchemaError as e:
self.structure.logger.error(f"Subtask {self.origin_task.id}\nInvalid action JSON: {e}")
# Don't forget to add it to the subtask actions list!
self.actions.append(new_action)
except SyntaxError as e:
self.structure.logger.error(f"Subtask {self.origin_task.id}\nSyntax error: {e}")

self.actions.append(self.__error_to_action(f"Action JSON validation error: {e}"))
except Exception as e:
self.structure.logger.error(f"Subtask {self.origin_task.id}\nError parsing tool action: {e}")
self.actions.append(self.__error_to_action(f"syntax error: {e}"))
except schema.SchemaError as e:
self.structure.logger.error(f"Subtask {self.origin_task.id}\nInvalid action JSON: {e}")

self.actions.append(self.__error_to_action(f"Action input parsing error: {e}"))
elif self.output is None and len(answer_matches) > 0:
self.output = TextArtifact(answer_matches[-1])
self.actions.append(self.__error_to_action(f"Action JSON validation error: {e}"))
except Exception as e:
self.structure.logger.error(f"Subtask {self.origin_task.id}\nError parsing tool action: {e}")

self.actions.append(self.__error_to_action(f"Action input parsing error: {e}"))

def __error_to_action(self, error: str) -> Action:
return ActionsSubtask.Action(tag="error", name="error", input={"error": error})
Expand Down
7 changes: 6 additions & 1 deletion griptape/tokenizers/openai_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
class OpenAiTokenizer(BaseTokenizer):
DEFAULT_OPENAI_GPT_3_COMPLETION_MODEL = "gpt-3.5-turbo-instruct"
DEFAULT_OPENAI_GPT_3_CHAT_MODEL = "gpt-3.5-turbo"
DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4"
DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4o"
DEFAULT_ENCODING = "cl100k_base"
DEFAULT_MAX_TOKENS = 2049
DEFAULT_MAX_OUTPUT_TOKENS = 4096
TOKEN_OFFSET = 8

# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {
"gpt-4o": 128000,
"gpt-4-1106": 128000,
"gpt-4-32k": 32768,
"gpt-4": 8192,
Expand Down Expand Up @@ -85,6 +86,7 @@ def count_tokens(self, text: str | list[dict], model: Optional[str] = None) -> i
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4o-2024-05-13",
}:
tokens_per_message = 3
tokens_per_name = 1
Expand All @@ -96,6 +98,9 @@ def count_tokens(self, text: str | list[dict], model: Optional[str] = None) -> i
elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
logging.info("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return self.count_tokens(text, model="gpt-3.5-turbo-0613")
elif "gpt-4o" in model:
logging.info("gpt-4o may update over time. Returning num tokens assuming gpt-4o-2024-05-13.")
return self.count_tokens(text, model="gpt-4o-2024-05-13")
elif "gpt-4" in model:
logging.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return self.count_tokens(text, model="gpt-4-0613")
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ nav:
- Web Scraper Drivers: "griptape-framework/drivers/web-scraper-drivers.md"
- Conversation Memory Drivers: "griptape-framework/drivers/conversation-memory-drivers.md"
- Event Listener Drivers: "griptape-framework/drivers/event-listener-drivers.md"
- Structure Run Drivers: "griptape-framework/drivers/structure-run-drivers.md"
- Data:
- Overview: "griptape-framework/data/index.md"
- Artifacts: "griptape-framework/data/artifacts.md"
Expand Down
Loading
Loading