Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement a tool permission store #1516

Merged
merged 15 commits into from
Mar 7, 2025
26 changes: 26 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ aws-sdk-bedrockruntime = "1.72.0"
# For GCP Vertex AI provider auth
jsonwebtoken = "9.3.1"

# For hashing tool-call arguments in the tool permission store
blake3 = "1.5"

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

Expand Down
2 changes: 2 additions & 0 deletions crates/goose/src/agents/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod capabilities;
pub mod extension;
mod factory;
mod permission_judge;
mod permission_store;
mod reference;
mod truncate;

Expand All @@ -11,3 +12,4 @@ pub use capabilities::Capabilities;
pub use extension::ExtensionConfig;
pub use factory::{register_agent, AgentFactory};
pub use permission_judge::detect_read_only_tools;
pub use permission_store::ToolPermissionStore;
149 changes: 149 additions & 0 deletions crates/goose/src/agents/permission_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use crate::message::ToolRequest;
use anyhow::Result;
use blake3::Hasher;
use chrono::Utc;
use etcetera::{choose_app_strategy, AppStrategy};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use std::{fs::File, path::PathBuf};

/// A single remembered allow/deny decision for one tool call.
///
/// Records are stored inside `ToolPermissionStore` and serialized to JSON.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolPermissionRecord {
// Name of the tool the decision applies to
tool_name: String,
// The decision itself: true = allowed, false = denied
allowed: bool,
context_hash: String, // Hash of the tool's arguments/context to differentiate similar calls
#[serde(skip_serializing_if = "Option::is_none")] // Don't serialize if None
readable_context: Option<String>, // Human-readable description of the tool call (filled from `ToolRequest::to_readable_string`)
timestamp: i64, // Unix timestamp (from `Utc::now().timestamp()`) at which the decision was recorded
expiry: Option<i64>, // Optional expiry timestamp; `None` means the record never expires
}

/// Persistent store of per-tool-call permission decisions.
///
/// Serialized as JSON to `tool_permissions.json` inside `permissions_dir`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ToolPermissionStore {
// Keyed by "<tool_name>:<context_hash>"; each key holds every decision recorded for it, oldest first
permissions: HashMap<String, Vec<ToolPermissionRecord>>,
version: u32, // For future schema migrations
#[serde(skip)] // Don't serialize this field
// Directory containing the JSON file; never read from the file, so `load` restores it after deserializing
permissions_dir: PathBuf,
}

impl Default for ToolPermissionStore {
fn default() -> Self {
Self::new()
}
}

impl ToolPermissionStore {
/// Creates an empty, in-memory store (schema version 1) without touching disk.
///
/// The permissions directory is the application's config directory as resolved
/// by `choose_app_strategy`; when that resolution fails, the relative path
/// `.config/goose` is used instead.
pub fn new() -> Self {
    let permissions_dir = match choose_app_strategy(crate::config::APP_STRATEGY.clone()) {
        Ok(strategy) => strategy.config_dir(),
        Err(_) => PathBuf::from(".config/goose"),
    };

    Self {
        permissions: HashMap::new(),
        version: 1,
        permissions_dir,
    }
}

/// Loads the permission store from `tool_permissions.json` in the permissions directory.
///
/// Returns a fresh, empty store (see `new`) when the file does not exist.
/// Otherwise the file is deserialized and expired records are pruned before
/// the store is handed back.
///
/// # Errors
///
/// Fails if the file cannot be opened or parsed as JSON, or if
/// `cleanup_expired` fails (it may write the pruned store back to disk).
pub fn load() -> Result<Self> {
// Build a default store first: it supplies the resolved permissions directory
let store = Self::new();
let file_path = store.permissions_dir.join("tool_permissions.json");

if !file_path.exists() {
return Ok(store);
}

// NOTE(review): `from_reader` is fed an unbuffered `File` here; wrapping it in a
// `BufReader` (or reading the file to a string first) may be faster - confirm.
let file = File::open(file_path)?;
let mut permissions: ToolPermissionStore = serde_json::from_reader(file)?;
// `permissions_dir` is `#[serde(skip)]`, so it is not read from the file; restore it here
permissions.permissions_dir = store.permissions_dir;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to clean up the expired entries?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, good call


// Clean up expired entries on load
permissions.cleanup_expired()?;

Ok(permissions)
}

/// Persists the store as pretty-printed JSON to `tool_permissions.json`.
///
/// The permissions directory is created if missing. The JSON is first written
/// to a sibling `.tmp` file which is then renamed over the real file, so the
/// target is replaced in a single step rather than rewritten in place.
///
/// # Errors
///
/// Fails if the directory cannot be created, the store cannot be serialized,
/// or the write/rename fails.
pub fn save(&self) -> anyhow::Result<()> {
    let dir = &self.permissions_dir;
    std::fs::create_dir_all(dir)?;

    let target = dir.join("tool_permissions.json");
    let staging = target.with_extension("tmp");

    // Stage the full document in the temporary file first
    let json = serde_json::to_string_pretty(self)?;
    std::fs::write(&staging, &json)?;

    // Swap the staged file into place
    std::fs::rename(staging, target)?;

    Ok(())
}

/// Looks up a previously recorded decision for `tool_request`.
///
/// The lookup key is `"<tool_name>:<context_hash>"`, where the hash comes from
/// `hash_tool_context`. Among the records stored under that key, the most
/// recently recorded one that has not expired wins.
///
/// Returns `Some(true)` / `Some(false)` for a stored allow / deny decision, and
/// `None` when there is no unexpired record or when the request's tool call is
/// itself an error (such a request can never have a stored decision).
pub fn check_permission(&self, tool_request: &ToolRequest) -> Option<bool> {
    // Bail out before hashing: `hash_tool_context` unwraps the tool call and
    // would panic on an invalid request.
    let tool_call = tool_request.tool_call.as_ref().ok()?;
    let context_hash = self.hash_tool_context(tool_request);
    let key = format!("{}:{}", tool_call.name, context_hash);

    // Read the clock once so every record is judged against the same instant.
    let now = Utc::now().timestamp();

    // Records are appended in chronological order, so scanning from the back
    // finds the newest unexpired decision without walking the whole list.
    self.permissions
        .get(&key)?
        .iter()
        .rev()
        .find(|record| record.expiry.is_none_or(|exp| exp > now))
        .map(|record| record.allowed)
}

/// Records the user's decision for `tool_request` and persists the store.
///
/// * `allowed` - whether the user approved (`true`) or denied (`false`) the call.
/// * `expiry_duration` - how long the decision stays valid; `None` means it
///   never expires.
///
/// The record is appended under the key `"<tool_name>:<context_hash>"`;
/// earlier records for the same key are kept.
///
/// # Errors
///
/// Fails if the request's tool call is itself an error (there is nothing to
/// key the decision on), or if saving the store to disk fails.
pub fn record_permission(
    &mut self,
    tool_request: &ToolRequest,
    allowed: bool,
    expiry_duration: Option<Duration>,
) -> anyhow::Result<()> {
    // Validate before hashing: `hash_tool_context` unwraps the tool call and
    // would panic on an invalid request.
    let tool_call = tool_request.tool_call.as_ref().map_err(|e| {
        anyhow::anyhow!("cannot record permission for invalid tool call: {}", e)
    })?;
    let context_hash = self.hash_tool_context(tool_request);
    let key = format!("{}:{}", tool_call.name, context_hash);

    // Read the clock once so `timestamp` and `expiry` share the same base.
    let now = Utc::now().timestamp();
    let expiry = expiry_duration.map(|d| {
        // Clamp instead of wrapping if the duration does not fit in an i64.
        let secs = i64::try_from(d.as_secs()).unwrap_or(i64::MAX);
        now.saturating_add(secs)
    });

    let record = ToolPermissionRecord {
        tool_name: tool_call.name.clone(),
        allowed,
        context_hash,
        readable_context: Some(tool_request.to_readable_string()),
        timestamp: now,
        expiry,
    };

    self.permissions.entry(key).or_default().push(record);

    self.save()
}

/// Returns the hex-encoded BLAKE3 hash of the tool call's arguments serialized as JSON.
///
/// Only the arguments are hashed; callers combine the result with the tool name
/// to form the store key. If the arguments cannot be serialized, the empty
/// string is hashed instead (`unwrap_or_default`).
///
/// # Panics
///
/// Panics if `tool_request.tool_call` is an `Err`, because the call is unwrapped.
fn hash_tool_context(&self, tool_request: &ToolRequest) -> String {
// Create a hash of the tool's arguments to differentiate similar calls
// This helps identify when the same tool is being used in a different context
let mut hasher = Hasher::new();
hasher.update(
serde_json::to_string(&tool_request.tool_call.as_ref().unwrap().arguments)
Copy link
Collaborator

@yingjiehe-xyz yingjiehe-xyz Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to hash on some argument params? I am wondering whether we will have low hit rate if we hash all, including argument param and value, like write file, the name can be different, but they are super similar

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar thought - I was thinking to hash by tool name only at first, but this would lump all bash commands in one hash.

Any ideas about how to do the hash at the right level of granularity?

Copy link
Collaborator

@yingjiehe-xyz yingjiehe-xyz Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, kind of tricky, introducing some cosine similarity check may be overkilled. I am wondering whether we can have a set of normalized keys? for example in bash case, we have "command":write, "file_path": xxx, and we maintain a set containing "file_path", if it is in the normalized key set, like "file_path", we replace the value with "normalized_value" before hashing. Even with such case, we cannot handle all argument keys, we can just start with bash? how do you think?

.unwrap_or_default()
.as_bytes(),
);
hasher.finalize().to_hex().to_string()
}

/// Drops every expired record, removes keys left with no records, and saves
/// the store if anything was actually removed.
///
/// A record is kept when it has no expiry or its expiry lies strictly in the
/// future.
///
/// # Errors
///
/// Fails only if the store was modified and writing it back to disk fails.
pub fn cleanup_expired(&mut self) -> anyhow::Result<()> {
    let now = Utc::now().timestamp();
    let mut changed = false;

    self.permissions.retain(|_, records| {
        let before = records.len();
        records.retain(|record| record.expiry.is_none_or(|exp| exp > now));

        // Any pruned record must trigger a save, not only the case where the
        // whole key disappears; otherwise partially pruned lists are never
        // written back. An already-empty list also counts, since its key is
        // removed below.
        changed |= records.len() != before || records.is_empty();
        !records.is_empty()
    });

    if changed {
        self.save()?;
    }
    Ok(())
}
}
47 changes: 38 additions & 9 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use super::detect_read_only_tools;
use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::agents::ToolPermissionStore;
use crate::config::Config;
use crate::config::ExperimentManager;
use crate::message::{Message, ToolRequest};
Expand All @@ -28,6 +29,7 @@ use mcp_core::prompt::Prompt;
use mcp_core::protocol::GetPromptResult;
use mcp_core::{tool::Tool, Content};
use serde_json::{json, Value};
use std::time::Duration;

const MAX_TRUNCATION_ATTEMPTS: usize = 3;
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
Expand Down Expand Up @@ -265,19 +267,43 @@ impl Agent for TruncateAgent {
match mode.as_str() {
"approve" => {
let mut read_only_tools = Vec::new();
// Process each tool request sequentially with confirmation
if ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? {
read_only_tools = detect_read_only_tools(&capabilities, tool_requests.clone()).await;
let mut needs_confirmation = Vec::<&ToolRequest>::new();

// First check permissions for all tools
let store = ToolPermissionStore::load()?;
for request in tool_requests.iter() {
if let Ok(tool_call) = request.tool_call.clone() {
if let Some(allowed) = store.check_permission(request) {
if allowed {
let output = capabilities.dispatch_tool_call(tool_call).await;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to output something to let users know the permission is skipped due to cache?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

possibly, it could be a bit noisy imo

message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
} else {
needs_confirmation.push(request);
}
} else {
needs_confirmation.push(request);
}
}
}
for request in &tool_requests {

// Only check read-only status for tools needing confirmation
if !needs_confirmation.is_empty() && ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? {
read_only_tools = detect_read_only_tools(&capabilities, needs_confirmation.clone()).await;
}

// Process remaining tools that need confirmation
for request in &needs_confirmation {
if let Ok(tool_call) = request.tool_call.clone() {
// Skip confirmation if the tool_call.name is in the read_only_tools list
if read_only_tools.contains(&tool_call.name) {
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
} else {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
Expand All @@ -289,9 +315,12 @@ impl Agent for TruncateAgent {

// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
// Loop the recv until we have a matched req_id due to potential duplicate messages.
while let Some((req_id, confirmed)) = rx.recv().await {
if req_id == request.id {
// Store the user's response with 30-day expiration
let mut store = ToolPermissionStore::load()?;
store.record_permission(request, confirmed, Some(Duration::from_secs(30 * 24 * 60 * 60)))?;

if confirmed {
// User approved - dispatch the tool call
let output = capabilities.dispatch_tool_call(tool_call).await;
Expand Down
16 changes: 16 additions & 0 deletions crates/goose/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ pub struct ToolRequest {
pub tool_call: ToolResult<ToolCall>,
}

impl ToolRequest {
    /// Renders this request as a single human-readable string.
    ///
    /// A valid tool call becomes `"Tool: <name>, Args: <pretty JSON>"`, with
    /// `<<invalid json>>` standing in for arguments that fail to serialize.
    /// A failed tool call becomes `"Invalid tool call: <error>"`.
    pub fn to_readable_string(&self) -> String {
        let call = match self.tool_call.as_ref() {
            Ok(call) => call,
            Err(err) => return format!("Invalid tool call: {}", err),
        };

        let args = match serde_json::to_string_pretty(&call.arguments) {
            Ok(json) => json,
            Err(_) => "<<invalid json>>".to_string(),
        };

        format!("Tool: {}, Args: {}", call.name, args)
    }
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {
Expand Down
Loading