Merge branch 'dev' into refactor/loaders
collindutter committed Sep 25, 2024
2 parents e4af29f + 53bc38b commit cb0b868
Showing 78 changed files with 1,169 additions and 738 deletions.
3 changes: 3 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -15,6 +15,9 @@ updates:
update-types:
- "minor"
- "patch"
allow:
- dependency-type: production
- dependency-type: development
- package-ecosystem: "github-actions"
directory: "/"
schedule:
45 changes: 28 additions & 17 deletions CHANGELOG.md
@@ -6,25 +6,36 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
- `BaseFileLoader` for Loaders that load from a path.
- `BaseLoader.fetch()` method for fetching data from a source.
- `BaseLoader.parse()` method for parsing fetched data.
- `BaseFileManager.encoding` to specify the encoding when loading and saving files.
- `BaseWebScraperDriver.extract_page()` method for extracting data from an already scraped web page.
- `TextLoaderRetrievalRagModule.chunker` for specifying the chunking strategy.
- `file_utils.get_mime_type` utility for getting the MIME type of a file.
- `Workflow.input_tasks` and `Workflow.output_tasks` to access the input and output tasks of a Workflow.
- Ability to pass nested list of `Tasks` to `Structure.tasks` allowing for more complex declarative Structure definitions.


### Added
- Parameter `pipeline_task` on `HuggingFacePipelinePromptDriver` for creating different types of `Pipeline`s.
- `TavilyWebSearchDriver` to integrate Tavily's web search capabilities.

### Changed
- **BREAKING**: Removed `BaseFileManager.default_loader` and `BaseFileManager.loaders`.
- **BREAKING**: Loaders no longer chunk data, use a Chunker to chunk the data.
- **BREAKING**: Removed `fileutils.load_file` and `fileutils.load_files`.
- **BREAKING**: Removed `loaders-dataframe` and `loaders-audio` extras as they are no longer needed.
- **BREAKING**: `TextLoader`, `PdfLoader`, `ImageLoader`, and `AudioLoader` now take a `str | PathLike` instead of `bytes`.
- **BREAKING**: Removed `DataframeLoader`.
- `LocalFileManagerDriver.workdir` is now optional.
- `filetype` is now a core dependency.
- `FileManagerTool` now uses `filetype` for more accurate file type detection.
- `BaseFileLoader.load_file()` will now either return a `TextArtifact` or a `BlobArtifact` depending on whether `BaseFileManager.encoding` is set.
- **BREAKING**: Renamed parameters on several classes to `client`:
- `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver`.
- `bedrock_client` on `AmazonBedrockTitanEmbeddingDriver`.
- `bedrock_client` on `AmazonBedrockImageGenerationDriver`.
- `bedrock_client` on `AmazonBedrockImageQueryDriver`.
- `bedrock_client` on `AmazonBedrockPromptDriver`.
- `sagemaker_client` on `AmazonSageMakerJumpstartEmbeddingDriver`.
- `sagemaker_client` on `AmazonSageMakerJumpstartPromptDriver`.
- `sqs_client` on `AmazonSqsEventListenerDriver`.
- `iotdata_client` on `AwsIotCoreEventListenerDriver`.
- `s3_client` on `AmazonS3FileManagerDriver`.
- `s3_client` on `AwsS3Tool`.
- `iam_client` on `AwsIamTool`.
- `pusher_client` on `PusherEventListenerDriver`.
- `mq` on `MarqoVectorStoreDriver`.
- `model_client` on `GooglePromptDriver`.
- `model_client` on `GoogleTokenizer`.
- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `pipeline`.
- Several places where API clients are initialized are now lazy loaded.
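
The lazy loading described in the last entry can be illustrated with a small, standalone sketch. This is not the library's `lazy_property` implementation; it only assumes the pattern visible in the driver diffs of this commit (a private `_client` attribute that defaults to `None` and a `client` property that builds the client on first access). `FakeClient` and `Driver` are hypothetical names used purely for illustration.

```python
import functools


def lazy_property(attr_name=None):
    """Build a property that computes its value on first access and caches it.

    Illustrative sketch only; the real decorator in the library may differ.
    """

    def decorator(func):
        # By default, cache in a private attribute named after the method.
        cache_name = attr_name or f"_{func.__name__}"

        @functools.wraps(func)
        def getter(self):
            # Build the value only if nothing has been cached or supplied yet.
            if getattr(self, cache_name, None) is None:
                setattr(self, cache_name, func(self))
            return getattr(self, cache_name)

        def setter(self, value):
            setattr(self, cache_name, value)

        return property(getter, setter)

    return decorator


class FakeClient:
    """Hypothetical stand-in for an expensive API client."""

    created = 0

    def __init__(self):
        FakeClient.created += 1


class Driver:
    """Hypothetical driver that defers client creation until first use."""

    def __init__(self, client=None):
        # A caller-supplied client bypasses lazy construction entirely.
        self._client = client

    @lazy_property()
    def client(self):
        return FakeClient()
```

With this shape, constructing a `Driver` does not create a client; the first read of `driver.client` does, and later reads reuse the cached instance. Passing `client=...` at construction keeps the caller's object.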


## [0.32.0] - 2024-09-17

4 changes: 4 additions & 0 deletions docs/griptape-cloud/data-sources/create-data-source.md
@@ -10,6 +10,10 @@ You can [create a Data Source in the Griptape Cloud console](https://cloud.gript

You can scrape and ingest a single, public web page by providing a URL. If you wish to scrape multiple pages, you must create multiple Data Sources. However, you can then add all of the pages to the same Knowledge Base if you wish to access all the pages together.

### Amazon S3

You can connect Amazon S3 buckets, objects, and prefixes by providing their S3 URI(s). Supported file extensions include .pdf, .csv, .md, and most text-based file types.
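
As a rough illustration of the URI forms mentioned above, the following sketch splits an S3 URI into its bucket and key (or prefix). The helper name `split_s3_uri` and the example bucket and paths are hypothetical; Griptape Cloud's own validation may differ.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Return (bucket, key_or_prefix) for an s3:// URI.

    Hypothetical helper for illustration; an empty second element means
    the URI refers to the whole bucket.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri!r}")
    # The bucket is the authority part; the rest is the object key or prefix.
    return parsed.netloc, parsed.path.lstrip("/")


# A bucket, a prefix, and a single object (all names are made up).
examples = [
    "s3://my-bucket",
    "s3://my-bucket/reports/2024/",
    "s3://my-bucket/reports/2024/summary.pdf",
]
parts = [split_s3_uri(uri) for uri in examples]
```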

### Google Drive

You can ingest documents and spreadsheets stored in a Google Drive account. We support all standard file formats such as text, markdown, spreadsheets, and presentations.
7 changes: 7 additions & 0 deletions docs/griptape-framework/drivers/src/web_search_drivers_4.py
@@ -0,0 +1,7 @@
import os

from griptape.drivers import TavilyWebSearchDriver

driver = TavilyWebSearchDriver(api_key=os.environ["TAVILY_API_KEY"])

driver.search("griptape ai")
9 changes: 9 additions & 0 deletions docs/griptape-framework/drivers/src/web_search_drivers_5.py
@@ -0,0 +1,9 @@
from griptape.drivers import DuckDuckGoWebSearchDriver
from griptape.structures import Agent
from griptape.tools import PromptSummaryTool, WebSearchTool

agent = Agent(
tools=[WebSearchTool(web_search_driver=DuckDuckGoWebSearchDriver()), PromptSummaryTool(off_prompt=False)],
)

agent.run("Give me some websites with information about AI frameworks.")
60 changes: 52 additions & 8 deletions docs/griptape-framework/drivers/web-search-drivers.md
@@ -1,6 +1,6 @@
---
search:
boost: 2
boost: 2
---

## Overview
@@ -9,7 +9,47 @@ Web Search Drivers can be used to search for links from a search query. They are

* `search()` searches the web and returns a [ListArtifact](../../reference/griptape/artifacts/list_artifact.md) that contains JSON-serializable [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s with the search results.

## Vector Store Drivers
You can use Web Search Drivers with [Structures](../structures/agents.md):

```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_5.py"
```
```
ToolkitTask 45a53f1024494baab41a1f10a67017b1
Output: Here are some websites with information about AI
frameworks:
1. [The Top 16 AI Frameworks and Libraries: A Beginner's Guide -
DataCamp](https://www.datacamp.com/blog/top-ai-frameworks-and-lib
raries)
2. [AI Frameworks: Top Types To Adopt in 2024 -
Splunk](https://www.splunk.com/en_us/blog/learn/ai-frameworks.htm
l)
3. [Top AI Frameworks in 2024: A Review -
BairesDev](https://www.bairesdev.com/blog/ai-frameworks/)
4. [The Top 16 AI Frameworks and Libraries - AI
Slackers](https://aislackers.com/the-top-16-ai-frameworks-and-lib
raries/)
5. [Top AI Frameworks in 2024: Artificial Intelligence Frameworks
Comparison - Clockwise
Software](https://clockwise.software/blog/artificial-intelligence
-framework/)
```
Or use them independently:

```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_3.py"
```
```
{"title": "The Top 16 AI Frameworks and Libraries: A Beginner's Guide", "url": "https://www.datacamp.com/blog/top-ai-frameworks-and-libraries", "description": "PyTorch. Torch is an open-source machine learning library known for its dynamic computational graph and is favored by researchers. The framework is excellent for prototyping and experimentation. Moreover, it's empowered by growing community support, with tools like PyTorch being built on the library."}
{"title": "Top 11 AI Frameworks and Tools in 2024 | Fively | 5ly.co", "url": "https://5ly.co/blog/best-ai-frameworks/", "description": "Discover the top 11 modern artificial intelligence tools and frameworks to build robust architectures for your AI-powered apps. ... - Some advanced use cases may need further fine-tuning. Caffe 2. Now we move on to deep learning tools and frameworks. The first one is Caffe 2: an open-source deep learning framework with modularity and speed in ..."}
{"title": "The Top 16 AI Frameworks and Libraries | AI Slackers", "url": "https://aislackers.com/the-top-16-ai-frameworks-and-libraries/", "description": "Experiment with different frameworks to find the one that aligns with your needs and goals as a data practitioner. Embrace the world of AI frameworks, and embark on a journey of building smarter software with confidence. Discover the top AI frameworks and libraries like PyTorch, Scikit-Learn, TensorFlow, Keras, LangChain, and more."}
```
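
Because each result is a JSON string, it can be decoded with the standard library. The sketch below assumes records shaped like the output above (`title`, `url`, `description`); the sample data and the helper names `parse_results` and `to_markdown_links` are hypothetical. When working with real driver output, you would presumably read each artifact's `value` first.

```python
import json

# Hypothetical sample shaped like the output above: one JSON object per result.
raw_results = [
    '{"title": "Example A", "url": "https://example.com/a", "description": "First result."}',
    '{"title": "Example B", "url": "https://example.com/b", "description": "Second result."}',
]


def parse_results(lines):
    """Decode each JSON string into a dict."""
    return [json.loads(line) for line in lines]


def to_markdown_links(results):
    """Render each decoded result as a Markdown link."""
    return [f"[{result['title']}]({result['url']})" for result in results]


links = to_markdown_links(parse_results(raw_results))
```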


## Web Search Drivers

### Google

@@ -21,12 +61,6 @@ Example using `GoogleWebSearchDriver` directly:
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_1.py"
```

Example of using `GoogleWebSearchDriver` with an agent:

```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_2.py"
```

### DuckDuckGo

!!! info
@@ -39,3 +73,13 @@ Example of using `DuckDuckGoWebSearchDriver` directly:
```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_3.py"
```

### Tavily
!!! info
This driver requires the `drivers-web-search-tavily` [extra](../index.md#extras) and a Tavily [API key](https://app.tavily.com).

Example of using `TavilyWebSearchDriver` directly:

```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_4.py"
```
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -99,6 +99,7 @@
from .web_search.base_web_search_driver import BaseWebSearchDriver
from .web_search.google_web_search_driver import GoogleWebSearchDriver
from .web_search.duck_duck_go_web_search_driver import DuckDuckGoWebSearchDriver
from .web_search.tavily_web_search_driver import TavilyWebSearchDriver

from .event_listener.base_event_listener_driver import BaseEventListenerDriver
from .event_listener.amazon_sqs_event_listener_driver import AmazonSqsEventListenerDriver
@@ -213,6 +214,7 @@
"BaseWebSearchDriver",
"GoogleWebSearchDriver",
"DuckDuckGoWebSearchDriver",
"TavilyWebSearchDriver",
"BaseEventListenerDriver",
"AmazonSqsEventListenerDriver",
"WebhookEventListenerDriver",
@@ -4,10 +4,11 @@
from typing import Optional

import openai
from attrs import Factory, define, field
from attrs import define, field

from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.drivers import BaseAudioTranscriptionDriver
from griptape.utils.decorators import lazy_property


@define
@@ -17,12 +18,11 @@ class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
takes_self=True,
),
)
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> openai.OpenAI:
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
additional_params = {}
@@ -1,16 +1,18 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
from mypy_boto3_bedrock import BedrockClient

from griptape.tokenizers.base_tokenizer import BaseTokenizer

@@ -26,7 +28,7 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
`search_query` when querying your vector DB to find relevant documents.
session: Optionally provide custom `boto3.Session`.
tokenizer: Optionally provide custom `BedrockCohereTokenizer`.
bedrock_client: Optionally provide custom `bedrock-runtime` client.
client: Optionally provide custom `bedrock-runtime` client.
"""

DEFAULT_MODEL = "cohere.embed-english-v3"
@@ -38,15 +40,16 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
kw_only=True,
)
_client: BedrockClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> BedrockClient:
return self.session.client("bedrock-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
payload = {"input_type": self.input_type, "texts": [chunk]}

response = self.bedrock_client.invoke_model(
response = self.client.invoke_model(
body=json.dumps(payload),
modelId=self.model,
accept="*/*",
@@ -1,16 +1,18 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
from mypy_boto3_bedrock import BedrockClient

from griptape.tokenizers.base_tokenizer import BaseTokenizer

@@ -23,7 +25,7 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
model: Embedding model name. Defaults to DEFAULT_MODEL.
tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
session: Optionally provide custom `boto3.Session`.
bedrock_client: Optionally provide custom `bedrock-runtime` client.
client: Optionally provide custom `bedrock-runtime` client.
"""

DEFAULT_MODEL = "amazon.titan-embed-text-v1"
@@ -34,15 +36,16 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
kw_only=True,
)
_client: BedrockClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> BedrockClient:
return self.session.client("bedrock-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
payload = {"inputText": chunk}

response = self.bedrock_client.invoke_model(
response = self.client.invoke_model(
body=json.dumps(payload),
modelId=self.model,
accept="application/json",
@@ -1,32 +1,35 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional

from attrs import Factory, define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
from mypy_boto3_sagemaker import SageMakerClient


@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
sagemaker_client: Any = field(
default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True),
kw_only=True,
)
endpoint: str = field(kw_only=True, metadata={"serializable": True})
custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
_client: SageMakerClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> SageMakerClient:
return self.session.client("sagemaker-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
payload = {"text_inputs": chunk, "mode": "embedding"}

endpoint_response = self.sagemaker_client.invoke_endpoint(
endpoint_response = self.client.invoke_endpoint(
EndpointName=self.endpoint,
ContentType="application/json",
Body=json.dumps(payload).encode("utf-8"),