Skip to content

Commit

Permalink
Added TavilyWebSearchDriver
Browse files Browse the repository at this point in the history
  • Loading branch information
william-price01 committed Sep 16, 2024
1 parent 37d5582 commit 6246696
Show file tree
Hide file tree
Showing 11 changed files with 242 additions and 96 deletions.
7 changes: 7 additions & 0 deletions docs/griptape-framework/drivers/src/web_search_drivers_4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os

from griptape.drivers import TavilyWebSearchDriver

driver = TavilyWebSearchDriver(api_key=os.environ["TAVILY_API_KEY"])

driver.search("griptape ai")
9 changes: 9 additions & 0 deletions docs/griptape-framework/drivers/src/web_search_drivers_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from griptape.drivers import DuckDuckGoWebSearchDriver
from griptape.structures import Agent
from griptape.tools import PromptSummaryTool, WebSearchTool

agent = Agent(
tools=[WebSearchTool(web_search_driver=DuckDuckGoWebSearchDriver()), PromptSummaryTool(off_prompt=False)],
)

agent.run("Give me some websites with information about AI frameworks.")
34 changes: 26 additions & 8 deletions docs/griptape-framework/drivers/web-search-drivers.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
search:
boost: 2
boost: 2
---

## Overview
Expand All @@ -9,7 +9,19 @@ Web Search Drivers can be used to search for links from a search query. They are

* `search()` searches the web and returns a [ListArtifact](../../reference/griptape/artifacts/list_artifact.md) that contains JSON-serializable [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s with the search results.

## Vector Store Drivers
You can use Web Search Drivers with structures:

```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_5.py"
```
Or use them independently:

```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_3.py"
```


## Web Search Drivers

### Google

Expand All @@ -21,12 +33,6 @@ Example using `GoogleWebSearchDriver` directly:
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_1.py"
```

Example of using `GoogleWebSearchDriver` with an agent:

```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_2.py"
```

### DuckDuckGo

!!! info
Expand All @@ -39,3 +45,15 @@ Example of using `DuckDuckGoWebSearchDriver` directly:
```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_3.py"
```

### Tavily
!!! info
This driver requires the `drivers-web-search-tavily` [extra](../index.md#extras), and a Tavily [API-KEY](https://app.tavily.com).

The [TavilyWebSearchDriver](../../reference/griptape/drivers/web_search/tavily_web_search_driver.md) uses the [tavily-python](https://github.com/tavily-ai/tavily-python) SDK for web searching.

Example of using `TavilyWebSearchDriver` directly:

```python
--8<-- "docs/griptape-framework/drivers/src/web_search_drivers_4.py"
```
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
from .web_search.base_web_search_driver import BaseWebSearchDriver
from .web_search.google_web_search_driver import GoogleWebSearchDriver
from .web_search.duck_duck_go_web_search_driver import DuckDuckGoWebSearchDriver
from .web_search.tavily_web_search_driver import TavilyWebSearchDriver

from .event_listener.base_event_listener_driver import BaseEventListenerDriver
from .event_listener.amazon_sqs_event_listener_driver import AmazonSqsEventListenerDriver
Expand Down Expand Up @@ -213,6 +214,7 @@
"BaseWebSearchDriver",
"GoogleWebSearchDriver",
"DuckDuckGoWebSearchDriver",
"TavilyWebSearchDriver",
"BaseEventListenerDriver",
"AmazonSqsEventListenerDriver",
"WebhookEventListenerDriver",
Expand Down
2 changes: 0 additions & 2 deletions griptape/drivers/web_search/base_web_search_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
@define
class BaseWebSearchDriver(ABC):
results_count: int = field(default=5, kw_only=True)
language: str = field(default="en", kw_only=True)
country: str = field(default="us", kw_only=True)

@abstractmethod
def search(self, query: str, **kwargs) -> ListArtifact: ...
2 changes: 2 additions & 0 deletions griptape/drivers/web_search/duck_duck_go_web_search_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

@define
class DuckDuckGoWebSearchDriver(BaseWebSearchDriver):
language: str = field(default="en", kw_only=True)
country: str = field(default="us", kw_only=True)
client: DDGS = field(default=Factory(lambda: import_optional_dependency("duckduckgo_search").DDGS()), kw_only=True)

def search(self, query: str, **kwargs) -> ListArtifact:
Expand Down
2 changes: 2 additions & 0 deletions griptape/drivers/web_search/google_web_search_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
class GoogleWebSearchDriver(BaseWebSearchDriver):
api_key: str = field(kw_only=True)
search_id: str = field(kw_only=True)
language: str = field(default="en", kw_only=True)
country: str = field(default="us", kw_only=True)

def search(self, query: str, **kwargs) -> ListArtifact:
return ListArtifact([TextArtifact(json.dumps(result)) for result in self._search_google(query, **kwargs)])
Expand Down
37 changes: 37 additions & 0 deletions griptape/drivers/web_search/tavily_web_search_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.drivers import BaseWebSearchDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from tavily import TavilyClient


@define
class TavilyWebSearchDriver(BaseWebSearchDriver):
api_key: str = field(kw_only=True)
client: TavilyClient = field(
default=Factory(lambda self: import_optional_dependency("tavily").TavilyClient(self.api_key), takes_self=True),
kw_only=True,
)

def search(self, query: str, **kwargs) -> ListArtifact:
try:
response = self.client.search(query, max_results=self.results_count, **kwargs)
results = response["results"]
return ListArtifact(
[
TextArtifact(
json.dumps({"title": result["title"], "url": result["url"], "content": result["content"]})
)
for result in results
]
)
except Exception as e:
raise Exception(f"Error searching '{query}' with Tavily: {e}") from e
18 changes: 17 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 6246696

Please sign in to comment.