-
Notifications
You must be signed in to change notification settings - Fork 669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: implement a tool permission store #1516
Changes from all commits
b43a03a
89ec854
5c8d830
8ea4de2
389e171
d3e0fac
fb2a6b4
8a2affe
af71cf9
0b1e77c
bb266f5
865386e
6606c1f
1acbada
a271b83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
use crate::message::ToolRequest; | ||
use anyhow::Result; | ||
use blake3::Hasher; | ||
use chrono::Utc; | ||
use etcetera::{choose_app_strategy, AppStrategy}; | ||
use serde::{Deserialize, Serialize}; | ||
use std::collections::HashMap; | ||
use std::time::Duration; | ||
use std::{fs::File, path::PathBuf}; | ||
|
||
#[derive(Debug, Serialize, Deserialize, Clone)] | ||
pub struct ToolPermissionRecord { | ||
tool_name: String, | ||
allowed: bool, | ||
context_hash: String, // Hash of the tool's arguments/context to differentiate similar calls | ||
#[serde(skip_serializing_if = "Option::is_none")] // Don't serialize if None | ||
readable_context: Option<String>, // Add this field | ||
timestamp: i64, | ||
expiry: Option<i64>, // Optional expiry timestamp | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct ToolPermissionStore { | ||
permissions: HashMap<String, Vec<ToolPermissionRecord>>, | ||
version: u32, // For future schema migrations | ||
#[serde(skip)] // Don't serialize this field | ||
permissions_dir: PathBuf, | ||
} | ||
|
||
impl Default for ToolPermissionStore { | ||
fn default() -> Self { | ||
Self::new() | ||
} | ||
} | ||
|
||
impl ToolPermissionStore { | ||
pub fn new() -> Self { | ||
let permissions_dir = choose_app_strategy(crate::config::APP_STRATEGY.clone()) | ||
.map(|strategy| strategy.config_dir()) | ||
.unwrap_or_else(|_| PathBuf::from(".config/goose")); | ||
|
||
Self { | ||
permissions: HashMap::new(), | ||
version: 1, | ||
permissions_dir, | ||
} | ||
} | ||
|
||
pub fn load() -> Result<Self> { | ||
let store = Self::new(); | ||
let file_path = store.permissions_dir.join("tool_permissions.json"); | ||
|
||
if !file_path.exists() { | ||
return Ok(store); | ||
} | ||
|
||
let file = File::open(file_path)?; | ||
let mut permissions: ToolPermissionStore = serde_json::from_reader(file)?; | ||
permissions.permissions_dir = store.permissions_dir; | ||
|
||
// Clean up expired entries on load | ||
permissions.cleanup_expired()?; | ||
|
||
Ok(permissions) | ||
} | ||
|
||
pub fn save(&self) -> anyhow::Result<()> { | ||
std::fs::create_dir_all(&self.permissions_dir)?; | ||
|
||
let path = self.permissions_dir.join("tool_permissions.json"); | ||
let temp_path = path.with_extension("tmp"); | ||
|
||
// Write complete content to temporary file | ||
let content = serde_json::to_string_pretty(self)?; | ||
std::fs::write(&temp_path, &content)?; | ||
|
||
// Atomically rename temp file to target file | ||
std::fs::rename(temp_path, path)?; | ||
|
||
Ok(()) | ||
} | ||
|
||
pub fn check_permission(&self, tool_request: &ToolRequest) -> Option<bool> { | ||
let context_hash = self.hash_tool_context(tool_request); | ||
let tool_call = tool_request.tool_call.as_ref().unwrap(); | ||
let key = format!("{}:{}", tool_call.name, context_hash); | ||
|
||
self.permissions.get(&key).and_then(|records| { | ||
records | ||
.iter() | ||
.filter(|record| record.expiry.is_none_or(|exp| exp > Utc::now().timestamp())) | ||
.last() | ||
.map(|record| record.allowed) | ||
}) | ||
} | ||
|
||
pub fn record_permission( | ||
&mut self, | ||
tool_request: &ToolRequest, | ||
allowed: bool, | ||
expiry_duration: Option<Duration>, | ||
) -> anyhow::Result<()> { | ||
let context_hash = self.hash_tool_context(tool_request); | ||
let tool_call = tool_request.tool_call.as_ref().unwrap(); | ||
let key = format!("{}:{}", tool_call.name, context_hash); | ||
|
||
let record = ToolPermissionRecord { | ||
tool_name: tool_call.name.clone(), | ||
allowed, | ||
context_hash, | ||
readable_context: Some(tool_request.to_readable_string()), | ||
timestamp: Utc::now().timestamp(), | ||
expiry: expiry_duration.map(|d| Utc::now().timestamp() + d.as_secs() as i64), | ||
}; | ||
|
||
self.permissions.entry(key).or_default().push(record); | ||
|
||
self.save()?; | ||
Ok(()) | ||
} | ||
|
||
fn hash_tool_context(&self, tool_request: &ToolRequest) -> String { | ||
// Create a hash of the tool's arguments to differentiate similar calls | ||
// This helps identify when the same tool is being used in a different context | ||
let mut hasher = Hasher::new(); | ||
hasher.update( | ||
serde_json::to_string(&tool_request.tool_call.as_ref().unwrap().arguments) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we want to hash on some argument params? I am wondering whether we will have low hit rate if we hash all, including argument param and value, like write file, the name can be different, but they are super similar There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar thought - I was thinking to hash by tool name only at first, but this would lump all bash commands in one hash. Any ideas about how to do the hash at the right level of granularity? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, kind of tricky, introducing some cosine similarity check may be overkilled. I am wondering whether we can have a set of normalized keys? for example in bash case, we have "command":write, "file_path": xxx, and we maintain a set containing "file_path", if it is in the normalized key set, like "file_path", we replace the value with "normalized_value" before hashing. Even with such case, we cannot handle all argument keys, we can just start with bash? how do you think? |
||
.unwrap_or_default() | ||
.as_bytes(), | ||
); | ||
hasher.finalize().to_hex().to_string() | ||
} | ||
|
||
pub fn cleanup_expired(&mut self) -> anyhow::Result<()> { | ||
let now = Utc::now().timestamp(); | ||
let mut changed = false; | ||
|
||
self.permissions.retain(|_, records| { | ||
records.retain(|record| record.expiry.is_none_or(|exp| exp > now)); | ||
changed = changed || records.is_empty(); | ||
!records.is_empty() | ||
}); | ||
|
||
if changed { | ||
self.save()?; | ||
} | ||
Ok(()) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ use super::detect_read_only_tools; | |
use super::Agent; | ||
use crate::agents::capabilities::Capabilities; | ||
use crate::agents::extension::{ExtensionConfig, ExtensionResult}; | ||
use crate::agents::ToolPermissionStore; | ||
use crate::config::Config; | ||
use crate::config::ExperimentManager; | ||
use crate::message::{Message, ToolRequest}; | ||
|
@@ -28,6 +29,7 @@ use mcp_core::prompt::Prompt; | |
use mcp_core::protocol::GetPromptResult; | ||
use mcp_core::{tool::Tool, Content}; | ||
use serde_json::{json, Value}; | ||
use std::time::Duration; | ||
|
||
const MAX_TRUNCATION_ATTEMPTS: usize = 3; | ||
const ESTIMATE_FACTOR_DECAY: f32 = 0.9; | ||
|
@@ -265,19 +267,43 @@ impl Agent for TruncateAgent { | |
match mode.as_str() { | ||
"approve" => { | ||
let mut read_only_tools = Vec::new(); | ||
// Process each tool request sequentially with confirmation | ||
if ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? { | ||
read_only_tools = detect_read_only_tools(&capabilities, tool_requests.clone()).await; | ||
let mut needs_confirmation = Vec::<&ToolRequest>::new(); | ||
|
||
// First check permissions for all tools | ||
let store = ToolPermissionStore::load()?; | ||
for request in tool_requests.iter() { | ||
if let Ok(tool_call) = request.tool_call.clone() { | ||
if let Some(allowed) = store.check_permission(request) { | ||
if allowed { | ||
let output = capabilities.dispatch_tool_call(tool_call).await; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need to output something to let users know the permission is skipped due to cache? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. possibly, it could be a bit noisy imo |
||
message_tool_response = message_tool_response.with_tool_response( | ||
request.id.clone(), | ||
output, | ||
); | ||
} else { | ||
needs_confirmation.push(request); | ||
} | ||
} else { | ||
needs_confirmation.push(request); | ||
} | ||
} | ||
} | ||
for request in &tool_requests { | ||
|
||
// Only check read-only status for tools needing confirmation | ||
if !needs_confirmation.is_empty() && ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? { | ||
read_only_tools = detect_read_only_tools(&capabilities, needs_confirmation.clone()).await; | ||
} | ||
|
||
// Process remaining tools that need confirmation | ||
for request in &needs_confirmation { | ||
if let Ok(tool_call) = request.tool_call.clone() { | ||
// Skip confirmation if the tool_call.name is in the read_only_tools list | ||
if read_only_tools.contains(&tool_call.name) { | ||
let output = capabilities.dispatch_tool_call(tool_call).await; | ||
message_tool_response = message_tool_response.with_tool_response( | ||
request.id.clone(), | ||
output, | ||
); | ||
message_tool_response = message_tool_response.with_tool_response( | ||
request.id.clone(), | ||
output, | ||
); | ||
} else { | ||
let confirmation = Message::user().with_tool_confirmation_request( | ||
request.id.clone(), | ||
|
@@ -289,9 +315,12 @@ impl Agent for TruncateAgent { | |
|
||
// Wait for confirmation response through the channel | ||
let mut rx = self.confirmation_rx.lock().await; | ||
// Loop the recv until we have a matched req_id due to potential duplicate messages. | ||
while let Some((req_id, confirmed)) = rx.recv().await { | ||
if req_id == request.id { | ||
// Store the user's response with 30-day expiration | ||
let mut store = ToolPermissionStore::load()?; | ||
store.record_permission(request, confirmed, Some(Duration::from_secs(30 * 24 * 60 * 60)))?; | ||
|
||
if confirmed { | ||
// User approved - dispatch the tool call | ||
let output = capabilities.dispatch_tool_call(tool_call).await; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to clean up the expired entries?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, good call