Implement Groq chat response handling and persistence #747

Merged: 19 commits, Feb 21, 2025
8 changes: 8 additions & 0 deletions backend/src/server/config.rs
@@ -1,6 +1,7 @@
use super::services::{
deepseek::DeepSeekService,
github_issue::GitHubService,
groq::GroqService,
oauth::{github::GitHubOAuth, scramble::ScrambleOAuth, OAuthConfig},
openrouter::OpenRouterService,
solver::SolverService,
@@ -28,6 +29,7 @@ pub struct AppState {
pub scramble_oauth: Arc<ScrambleOAuth>,
pub pool: PgPool,
pub frontend_url: String,
pub groq: Arc<GroqService>,
}

#[derive(Clone)]
@@ -129,6 +131,11 @@ pub fn configure_app_with_config(pool: PgPool, config: Option<AppConfig>) -> Router

let chat_model = Arc::new(DeepSeekService::with_base_url(api_key, base_url));

// Initialize Groq service
let groq = Arc::new(GroqService::new(
env::var("GROQ_API_KEY").expect("GROQ_API_KEY must be set"),
));

let tools = create_tools();

let ws_state = Arc::new(WebSocketState::new(
@@ -177,6 +184,7 @@ pub fn configure_app_with_config(pool: PgPool, config: Option<AppConfig>) -> Router
scramble_oauth,
pool: pool.clone(),
frontend_url: config.frontend_url,
groq,
};

// Create the main router
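
For context on how the new field is consumed: handlers reach the service through axum's State extractor and the Arc stored on AppState, exactly as the chat.rs changes below do. A minimal standalone sketch, not part of this PR (the handler name and response shape are illustrative):

use axum::{extract::State, http::StatusCode, Json};
use serde_json::json;

// Hypothetical handler showing how `state.groq` is reached from `AppState`
// (assumes the `AppState` defined in config.rs is in scope).
async fn groq_ping(
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
    let messages = vec![json!({ "role": "user", "content": "ping" })];
    let (reply, _reasoning) = state
        .groq
        .chat_with_history(messages, false)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    Ok(Json(json!({ "reply": reply })))
}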
174 changes: 122 additions & 52 deletions backend/src/server/handlers/chat.rs
@@ -81,7 +81,7 @@ pub async fn start_repo_chat(
info!("Created conversation with id: {}", conversation.id);

// Create initial message with repos metadata
let message = chat_db
let _message = chat_db
.create_message(&CreateMessageRequest {
conversation_id: conversation.id,
user_id: user_id.clone(),
@@ -101,11 +101,53 @@
)
})?;

info!("Created message with id: {}", message.id);
info!("Created message with id: {}", _message.id);

// Convert message to Groq format
let messages = vec![json!({
"role": "user",
"content": request.message
})];

// Get Groq response
let (ai_response, _) = state
.groq
.chat_with_history(messages, false)
.await
.map_err(|e| {
error!("Failed to get Groq response: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get AI response: {}", e),
)
})?;

// Save AI response
let ai_message = chat_db
.create_message(&CreateMessageRequest {
conversation_id: conversation.id,
user_id: user_id.clone(),
role: "assistant".to_string(),
content: ai_response,
metadata: Some(json!({
"repos": request.repos
})),
tool_calls: None,
})
.await
.map_err(|e| {
error!("Failed to save AI response: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to save AI response: {}", e),
)
})?;

info!("Created AI message with id: {}", ai_message.id);

Ok(Json(StartChatResponse {
id: conversation.id.to_string(),
initial_message: message.content,
initial_message: _message.content,
}))
}

@@ -128,51 +170,27 @@ pub async fn send_message(
"No session found. Please log in.".to_string(),
));
};
info!("Using user_id: {}", user_id);

// If no repos provided in the request, try to get them from the conversation's first message
let metadata = if let Some(repos) = request.repos {
info!("Using repos from request: {:?}", repos);
Some(json!({ "repos": repos }))
} else {
// Get the first message of the conversation to find the repos
let messages = chat_db
.get_conversation_messages(request.conversation_id)
.await
.map_err(|e| {
error!("Failed to fetch conversation messages: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to fetch conversation messages: {}", e),
)
})?;

// Find the first message with repos metadata
let first_message_repos = messages.iter().find_map(|msg| {
msg.metadata.as_ref().and_then(|meta| {
meta.get("repos")
.and_then(|repos| repos.as_array())
.map(|repos| repos.to_owned())
})
});

if let Some(repos) = first_message_repos {
info!("Using repos from first message: {:?}", repos);
Some(json!({ "repos": repos }))
} else {
info!("No repos found in request or first message");
None
}
};
// Get conversation history
let messages = chat_db
.get_conversation_messages(request.conversation_id)
.await
.map_err(|e| {
error!("Failed to get conversation history: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get conversation history: {}", e),
)
})?;

// Create message
let message = chat_db
// Create user message
let _message = chat_db
.create_message(&CreateMessageRequest {
conversation_id: request.conversation_id,
user_id: user_id.clone(),
role: "user".to_string(),
content: request.message.clone(),
metadata,
metadata: request.repos.clone().map(|repos| json!({ "repos": repos })),
tool_calls: None,
})
.await
@@ -184,11 +202,58 @@
)
})?;

info!("Created message with id: {}", message.id);
// Convert messages to Groq format
let mut chat_messages: Vec<serde_json::Value> = messages
.iter()
.map(|msg| {
json!({
"role": msg.role,
"content": msg.content
})
})
.collect();

// Add current message
chat_messages.push(json!({
"role": "user",
"content": request.message
}));

// Get AI response with full history
let (ai_response, _) = state
.groq
.chat_with_history(chat_messages, false)
.await
.map_err(|e| {
error!("Failed to get Groq response: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get AI response: {}", e),
)
})?;

// Save AI response
let ai_message = chat_db
.create_message(&CreateMessageRequest {
conversation_id: request.conversation_id,
user_id,
role: "assistant".to_string(),
content: ai_response.clone(),
metadata: request.repos.clone().map(|repos| json!({ "repos": repos })),
tool_calls: None,
})
.await
.map_err(|e| {
error!("Failed to save AI response: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to save AI response: {}", e),
)
})?;

Ok(Json(SendMessageResponse {
id: message.id.to_string(),
message: message.content,
id: ai_message.id.to_string(),
message: ai_response,
}))
}

@@ -197,7 +262,8 @@ pub async fn get_conversation_messages(
State(state): State<AppState>,
Path(conversation_id): Path<Uuid>,
) -> Result<Json<Vec<Message>>, (StatusCode, String)> {
info!("Fetching messages for conversation: {}", conversation_id);
// Create chat database service
let chat_db = ChatDatabaseService::new(state.pool);

// Get user info from session
let user_id = if let Some(session_cookie) = cookies.get(SESSION_COOKIE_NAME) {
Expand All @@ -208,24 +274,28 @@ pub async fn get_conversation_messages(
"No session found. Please log in.".to_string(),
));
};
info!("Using user_id: {}", user_id);

// Create chat database service
let chat_db = ChatDatabaseService::new(state.pool);

// Get messages
let messages = chat_db
.get_conversation_messages(conversation_id)
.await
.map_err(|e| {
error!("Failed to fetch conversation messages: {:?}", e);
error!("Failed to get messages: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to fetch conversation messages: {}", e),
format!("Failed to get messages: {}", e),
)
})?;

info!("Found {} messages", messages.len());
// Verify user has access to this conversation
if let Some(first_message) = messages.first() {
if first_message.user_id != user_id {
return Err((
StatusCode::FORBIDDEN,
"You do not have access to this conversation".to_string(),
));
}
}

Ok(Json(messages))
}
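
Both handlers above build the Groq payload the same way: stored history is mapped to Groq's role/content message format, then the incoming user message is appended before calling chat_with_history. A minimal sketch of that conversion, with a simplified Message struct standing in for the persisted chat model (field names assumed from the handler code):

use serde_json::{json, Value};

// Simplified stand-in for the persisted chat message type.
struct Message {
    role: String,
    content: String,
}

// Mirrors the handler logic: map history to Groq's role/content format,
// then append the new user message as the final entry.
fn build_groq_messages(history: &[Message], new_message: &str) -> Vec<Value> {
    let mut chat_messages: Vec<Value> = history
        .iter()
        .map(|msg| json!({ "role": msg.role, "content": msg.content }))
        .collect();
    chat_messages.push(json!({ "role": "user", "content": new_message }));
    chat_messages
}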
55 changes: 33 additions & 22 deletions backend/src/server/services/groq/service.rs
@@ -1,5 +1,6 @@
use anyhow::{Context, Result};
use reqwest::{Client, ClientBuilder};
use serde_json::Value;
use std::pin::Pin;
use std::time::Duration;
use tokio_stream::Stream;
@@ -44,35 +45,19 @@ impl GroqService {
base_url,
}
}
}

#[async_trait::async_trait]
impl Gateway for GroqService {
fn metadata(&self) -> GatewayMetadata {
GatewayMetadata {
name: "Groq".to_string(),
openai_compatible: true,
supported_features: vec!["chat".to_string(), "streaming".to_string()],
default_model: "mixtral-8x7b-32768".to_string(),
available_models: vec![
"llama-3.1-8b-instant".to_string(),
"llama-3.3-70b-versatile".to_string(),
"mixtral-8x7b-32768".to_string(),
],
}
}

async fn chat(&self, prompt: String, use_reasoner: bool) -> Result<(String, Option<String>)> {
pub async fn chat_with_history(
&self,
messages: Vec<Value>,
use_reasoner: bool,
) -> Result<(String, Option<String>)> {
let response = self
.client
.post(format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&serde_json::json!({
"model": self.metadata().default_model,
"messages": [{
"role": "user",
"content": prompt
}],
"messages": messages,
"temperature": if use_reasoner { 0.0 } else { 0.7 },
"stream": false
}))
@@ -100,6 +85,32 @@ impl Gateway for GroqService {

Ok((content, None))
}
}

#[async_trait::async_trait]
impl Gateway for GroqService {
fn metadata(&self) -> GatewayMetadata {
GatewayMetadata {
name: "Groq".to_string(),
openai_compatible: true,
supported_features: vec!["chat".to_string(), "streaming".to_string()],
default_model: "llama-3.1-8b-instant".to_string(),
available_models: vec![
"llama-3.1-8b-instant".to_string(),
"llama-3.3-70b-versatile".to_string(),
],
}
}

async fn chat(&self, prompt: String, use_reasoner: bool) -> Result<(String, Option<String>)> {
// Convert single prompt into messages format
let messages = vec![serde_json::json!({
"role": "user",
"content": prompt
})];

self.chat_with_history(messages, use_reasoner).await
}

async fn chat_stream(
&self,
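
Outside the Gateway trait, chat_with_history can also be called directly with an arbitrary message list. A minimal sketch, assuming GroqService::new takes only the API key (as it is called in config.rs) and that the type is imported from wherever the crate exposes it:

use serde_json::json;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Read the key the same way config.rs does; the import path for
    // GroqService depends on the crate layout and is assumed here.
    let groq = GroqService::new(std::env::var("GROQ_API_KEY")?);

    let messages = vec![
        json!({ "role": "system", "content": "You are a concise assistant." }),
        json!({ "role": "user", "content": "Summarize this repository's chat flow." }),
    ];

    // `use_reasoner = true` drops the temperature to 0.0 per the service code.
    let (content, _reasoning) = groq.chat_with_history(messages, true).await?;
    println!("{content}");
    Ok(())
}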