Skip to content

Commit

Permalink
refactor: users can choose whether to initialize the agent RAG (#968)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Nov 4, 2024
1 parent 79973a2 commit 0a324b6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 38 deletions.
35 changes: 22 additions & 13 deletions src/config/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,29 @@ impl Agent {
let rag = if rag_path.exists() {
Some(Arc::new(Rag::load(config, "rag", &rag_path)?))
} else if !definition.documents.is_empty() {
println!("The agent has the documents, initializing RAG...");
let mut document_paths = vec![];
for path in &definition.documents {
if is_url(path) {
document_paths.push(path.to_string());
} else {
let new_path = safe_join_path(&functions_dir, path)
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?;
document_paths.push(new_path.display().to_string())
let mut ans = false;
if *IS_STDOUT_TERMINAL {
ans = Confirm::new("The agent has the documents, init RAG?")
.with_default(true)
.prompt()?;
}
if ans {
let mut document_paths = vec![];
for path in &definition.documents {
if is_url(path) {
document_paths.push(path.to_string());
} else {
let new_path = safe_join_path(&functions_dir, path)
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?;
document_paths.push(new_path.display().to_string())
}
}
let rag =
Rag::init(config, "rag", &rag_path, &document_paths, abort_signal).await?;
Some(Arc::new(rag))
} else {
None
}
Some(Arc::new(
Rag::init(config, "rag", &rag_path, &document_paths, abort_signal).await?,
))
} else {
None
};
Expand Down Expand Up @@ -375,7 +384,7 @@ fn init_variables(
.prompt()?;
variable.value = value;
} else {
bail!("Failed to init agent variables in the script mode.");
bail!("Failed to init agent variables in non-interactive mode");
}
}
}
Expand Down
33 changes: 8 additions & 25 deletions src/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ impl Rag {
doc_paths: &[String],
abort_signal: AbortSignal,
) -> Result<Self> {
debug!("init rag: {name}");
if !*IS_STDOUT_TERMINAL {
bail!("Failed to init rag in non-interactive mode");
}
println!("🚀 Initializing RAG...");
let (embedding_model, chunk_size, chunk_overlap) = Self::create_config(config)?;
let (reranker_model, top_k) = {
let config = config.read();
Expand All @@ -84,7 +87,6 @@ impl Rag {
if paths.is_empty() {
paths = add_documents()?;
};
debug!("doc paths: {paths:?}");
let loaders = config.read().document_loaders.clone();
let spinner = create_spinner("Starting").await;
tokio::select! {
Expand All @@ -98,7 +100,7 @@ impl Rag {
},
};
if rag.save()? {
println!("✨ Saved rag to '{}'.", save_path.display());
println!("✨ Saved RAG to '{}'.", save_path.display());
}
Ok(rag)
}
Expand Down Expand Up @@ -177,13 +179,7 @@ impl Rag {
if models.is_empty() {
bail!("No available embedding model");
}
if *IS_STDOUT_TERMINAL {
select_embedding_model(&models)?
} else {
let value = models[0].id();
println!("Select embedding model: {value}");
value
}
select_embedding_model(&models)?
}
};
let embedding_model = Model::retrieve_embedding(&config.read(), &embedding_model_id)?;
Expand All @@ -193,15 +189,7 @@ impl Rag {
println!("Set chunk size: {value}");
value
}
None => {
if *IS_STDOUT_TERMINAL {
set_chunk_size(&embedding_model)?
} else {
let value = embedding_model.default_chunk_size();
println!("Set chunk size: {value}");
value
}
}
None => set_chunk_size(&embedding_model)?,
};
let chunk_overlap = match chunk_overlap {
Some(value) => {
Expand All @@ -210,12 +198,7 @@ impl Rag {
}
None => {
let value = chunk_size / 20;
if *IS_STDOUT_TERMINAL {
set_chunk_overlay(value)?
} else {
println!("Set chunk overlay: {value}");
value
}
set_chunk_overlay(value)?
}
};

Expand Down

0 comments on commit 0a324b6

Please sign in to comment.