From c50f2dc50e9350ccb8498b92fb6d2ee5b5f82919 Mon Sep 17 00:00:00 2001
From: Daniel Bevenius <daniel.bevenius@gmail.com>
Date: Wed, 18 Oct 2023 08:51:52 +0200
Subject: [PATCH] feat: add google serper (search) tool

This commit adds a new tool that can be used to search the internet for
information similar to BingSearch. The tool uses the Google Serper API
to perform the search.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
---
 .../examples/self_ask_with_google_search.rs   |  32 ++++
 crates/llm-chain/examples/google_serper.rs    |  12 ++
 .../src/tools/tools/google_serper.rs          | 137 ++++++++++++++++++
 crates/llm-chain/src/tools/tools/mod.rs       |   2 +
 4 files changed, 183 insertions(+)
 create mode 100644 crates/llm-chain-openai/examples/self_ask_with_google_search.rs
 create mode 100644 crates/llm-chain/examples/google_serper.rs
 create mode 100644 crates/llm-chain/src/tools/tools/google_serper.rs

diff --git a/crates/llm-chain-openai/examples/self_ask_with_google_search.rs b/crates/llm-chain-openai/examples/self_ask_with_google_search.rs
new file mode 100644
index 00000000..e9530805
--- /dev/null
+++ b/crates/llm-chain-openai/examples/self_ask_with_google_search.rs
@@ -0,0 +1,32 @@
+use llm_chain::{
+    agents::self_ask_with_search::{Agent, EarlyStoppingConfig},
+    executor,
+    tools::tools::GoogleSerper,
+};
+
+#[tokio::main(flavor = "current_thread")]
+async fn main() {
+    let executor = executor!().unwrap();
+    let serper_api_key = std::env::var("SERPER_API_KEY").unwrap();
+    let search_tool = GoogleSerper::new(serper_api_key);
+    let agent = Agent::new(
+        executor,
+        search_tool,
+        EarlyStoppingConfig {
+            max_iterations: Some(10),
+            max_time_elapsed_seconds: Some(30.0),
+        },
+    );
+    let (res, intermediate_steps) = agent
+        .run("What is the capital of the birthplace of Levy Mwanawasa?")
+        .await
+        .unwrap();
+    println!(
+        "Are followup questions needed here: {}",
+        agent.build_agent_scratchpad(&intermediate_steps)
+    );
+    println!(
+        "Agent final answer: {}",
+        res.return_values.get("output").unwrap()
+    );
+}
diff --git a/crates/llm-chain/examples/google_serper.rs b/crates/llm-chain/examples/google_serper.rs
new file mode 100644
index 00000000..848f0544
--- /dev/null
+++ b/crates/llm-chain/examples/google_serper.rs
@@ -0,0 +1,12 @@
+use llm_chain::tools::{tools::GoogleSerper, Tool};
+
+#[tokio::main(flavor = "current_thread")]
+async fn main() {
+    let serper_api_key = std::env::var("SERPER_API_KEY").unwrap();
+    let serper = GoogleSerper::new(serper_api_key);
+    let result = serper
+        .invoke_typed(&"Who was the inventor of Catan?".into())
+        .await
+        .unwrap();
+    println!("Best answer from Google Serper: {}", result.result);
+}
diff --git a/crates/llm-chain/src/tools/tools/google_serper.rs b/crates/llm-chain/src/tools/tools/google_serper.rs
new file mode 100644
index 00000000..db93d036
--- /dev/null
+++ b/crates/llm-chain/src/tools/tools/google_serper.rs
@@ -0,0 +1,137 @@
+use async_trait::async_trait;
+use reqwest::Method;
+use serde::{Deserialize, Serialize};
+use thiserror::Error;
+
+use crate::tools::{Describe, Tool, ToolDescription, ToolError};
+
+pub struct GoogleSerper {
+    api_key: String,
+}
+
+impl GoogleSerper {
+    pub fn new(api_key: String) -> Self {
+        Self { api_key }
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct GoogleSerperInput {
+    pub query: String,
+}
+
+impl From<&str> for GoogleSerperInput {
+    fn from(value: &str) -> Self {
+        Self {
+            query: value.into(),
+        }
+    }
+}
+
+impl From<String> for GoogleSerperInput {
+    fn from(value: String) -> Self {
+        Self { query: value }
+    }
+}
+
+impl Describe for GoogleSerperInput {
+    fn describe() -> crate::tools::Format {
+        vec![("query", "Search query to find necessary information").into()].into()
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct GoogleSerperOutput {
+    pub result: String,
+}
+
+impl From<String> for GoogleSerperOutput {
+    fn from(value: String) -> Self {
+        Self { result: value }
+    }
+}
+
+impl From<GoogleSerperOutput> for String {
+    fn from(val: GoogleSerperOutput) -> Self {
+        val.result
+    }
+}
+
+impl Describe for GoogleSerperOutput {
+    fn describe() -> crate::tools::Format {
+        vec![(
+            "result",
+            "Information retrieved from the internet that should answer your query",
+        )
+            .into()]
+        .into()
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct SiteLinks {
+    title: String,
+    link: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Organic {
+    title: String,
+    link: String,
+    snippet: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct GoogleSerperResult {
+    organic: Vec<Organic>,
+}
+
+#[derive(Debug, Error)]
+pub enum GoogleSerperError {
+    #[error("No search results were returned")]
+    NoResults,
+    #[error(transparent)]
+    Yaml(#[from] serde_yaml::Error),
+    #[error(transparent)]
+    Request(#[from] reqwest::Error),
+}
+
+impl ToolError for GoogleSerperError {}
+
+#[async_trait]
+impl Tool for GoogleSerper {
+    type Input = GoogleSerperInput;
+
+    type Output = GoogleSerperOutput;
+
+    type Error = GoogleSerperError;
+
+    async fn invoke_typed(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
+        let client = reqwest::Client::new();
+        let response = client
+            .request(Method::GET, "https://google.serper.dev/search")
+            .query(&[("q", &input.query)])
+            .header("X-API-KEY", self.api_key.clone())
+            .send()
+            .await?
+            .json::<GoogleSerperResult>()
+            .await?;
+        let answer = response
+            .organic
+            .first()
+            .ok_or(GoogleSerperError::NoResults)?
+            .snippet
+            .clone();
+        Ok(answer.into())
+    }
+
+    fn description(&self) -> ToolDescription {
+        ToolDescription::new(
+            "Google search",
+            "Useful for when you need to answer questions about current events. Input should be a search query.",
+            "Use this to get information about current events.",
+            GoogleSerperInput::describe(),
+            GoogleSerperOutput::describe(),
+        )
+    }
+}
diff --git a/crates/llm-chain/src/tools/tools/mod.rs b/crates/llm-chain/src/tools/tools/mod.rs
index dd30af5b..cb98541e 100644
--- a/crates/llm-chain/src/tools/tools/mod.rs
+++ b/crates/llm-chain/src/tools/tools/mod.rs
@@ -4,11 +4,13 @@
 mod bash;
 mod bing_search;
 mod exit;
+mod google_serper;
 mod python;
 mod vectorstore;
 pub use bash::{BashTool, BashToolError, BashToolInput, BashToolOutput};
 pub use bing_search::{BingSearch, BingSearchError, BingSearchInput, BingSearchOutput};
 pub use exit::{ExitTool, ExitToolError, ExitToolInput, ExitToolOutput};
+pub use google_serper::{GoogleSerper, GoogleSerperError, GoogleSerperInput, GoogleSerperOutput};
 pub use python::{PythonTool, PythonToolError, PythonToolInput, PythonToolOutput};
 pub use vectorstore::{
     VectorStoreTool, VectorStoreToolError, VectorStoreToolInput, VectorStoreToolOutput,