diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 09eb5988..361b8e21 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
     "langgraph-api",
     "fastapi",
     "google-genai",
+    "google-ai-generativelanguage>=0.6.18",
 ]
diff --git a/backend/src/agent/graph.py b/backend/src/agent/graph.py
index dae64b77..c3c78bf5 100644
--- a/backend/src/agent/graph.py
+++ b/backend/src/agent/graph.py
@@ -24,6 +24,7 @@
     answer_instructions,
 )
 from langchain_google_genai import ChatGoogleGenerativeAI
+from google.ai.generativelanguage_v1beta.types import Tool as GenAITool
 from agent.utils import (
     get_citations,
     get_research_topic,
@@ -111,22 +112,25 @@ def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
         research_topic=state["search_query"],
     )
 
-    # Uses the google genai client as the langchain client doesn't return grounding metadata
-    response = genai_client.models.generate_content(
+    llm = ChatGoogleGenerativeAI(
         model=configurable.query_generator_model,
-        contents=formatted_prompt,
-        config={
-            "tools": [{"google_search": {}}],
-            "temperature": 0,
-        },
+        temperature=0,
+        api_key=os.getenv("GEMINI_API_KEY"),
     )
+
+    response = llm.invoke(
+        formatted_prompt,
+        tools=[GenAITool(google_search={})],
+    )
+
     # resolve the urls to short urls for saving tokens and time
     resolved_urls = resolve_urls(
-        response.candidates[0].grounding_metadata.grounding_chunks, state["id"]
+        response.response_metadata["grounding_metadata"]["grounding_chunks"],
+        state["id"],
     )
     # Gets the citations and adds them to the generated text
     citations = get_citations(response, resolved_urls)
-    modified_text = insert_citation_markers(response.text, citations)
+    modified_text = insert_citation_markers(response.content, citations)
     sources_gathered = [item for citation in citations for item in citation["segments"]]
 
     return {
diff --git a/backend/src/agent/utils.py b/backend/src/agent/utils.py
index d02c8d91..1fbadd16 100644
--- a/backend/src/agent/utils.py
+++ b/backend/src/agent/utils.py
@@ -25,7 +25,7 @@ def resolve_urls(urls_to_resolve: List[Any], id: int) -> Dict[str, str]:
     Ensures each original URL gets a consistent shortened form while maintaining uniqueness.
     """
     prefix = f"https://vertexaisearch.cloud.google.com/id/"
-    urls = [site.web.uri for site in urls_to_resolve]
+    urls = [site["web"]["uri"] for site in urls_to_resolve]
 
     # Create a dictionary that maps each unique URL to its first occurrence index
     resolved_map = {}
@@ -85,10 +85,9 @@ def get_citations(response, resolved_urls_map):
     containing formatted markdown links to the supporting web chunks.
 
     Args:
-        response: The response object from the Gemini model, expected to have
-                  a structure including `candidates[0].grounding_metadata`.
-                  It also relies on a `resolved_map` being available in its
-                  scope to map chunk URIs to resolved URLs.
+        response: The response object from LangChain's ChatGoogleGenerativeAI, expected
+                  to have a structure including response_metadata["grounding_metadata"].
+        resolved_urls_map: A dictionary mapping original URLs to resolved URLs.
 
     Returns:
         list: A list of dictionaries, where each dictionary represents a citation
@@ -102,59 +101,56 @@ def get_citations(response, resolved_urls_map):
                                         links for each grounding chunk.
               - "segment_string" (str): A concatenated string of all markdown-
                                         formatted links for the citation.
-              Returns an empty list if no valid candidates or grounding supports
-              are found, or if essential data is missing.
+              Returns an empty list if no valid grounding supports are found, or if
+              essential data is missing.
     """
     citations = []
 
     # Ensure response and necessary nested structures are present
-    if not response or not response.candidates:
+    if not response:
         return citations
 
-    candidate = response.candidates[0]
     if (
-        not hasattr(candidate, "grounding_metadata")
-        or not candidate.grounding_metadata
-        or not hasattr(candidate.grounding_metadata, "grounding_supports")
+        "grounding_metadata" not in response.response_metadata
+        or not response.response_metadata["grounding_metadata"]
+        or "grounding_supports" not in response.response_metadata["grounding_metadata"]
     ):
         return citations
 
-    for support in candidate.grounding_metadata.grounding_supports:
+    grounding_metadata = response.response_metadata["grounding_metadata"]
+    for support in grounding_metadata["grounding_supports"]:
         citation = {}
 
         # Ensure segment information is present
-        if not hasattr(support, "segment") or support.segment is None:
+        if "segment" not in support or not support["segment"]:
             continue  # Skip this support if segment info is missing
 
         start_index = (
-            support.segment.start_index
-            if support.segment.start_index is not None
+            support["segment"]["start_index"]
+            if support["segment"]["start_index"] is not None
             else 0
         )
 
         # Ensure end_index is present to form a valid segment
-        if support.segment.end_index is None:
+        if support["segment"]["end_index"] is None:
             continue  # Skip if end_index is missing, as it's crucial
 
         # Add 1 to end_index to make it an exclusive end for slicing/range purposes
         # (assuming the API provides an inclusive end_index)
         citation["start_index"] = start_index
-        citation["end_index"] = support.segment.end_index
+        citation["end_index"] = support["segment"]["end_index"]
 
         citation["segments"] = []
-        if (
-            hasattr(support, "grounding_chunk_indices")
-            and support.grounding_chunk_indices
-        ):
-            for ind in support.grounding_chunk_indices:
+        if "grounding_chunk_indices" in support and support["grounding_chunk_indices"]:
+            for ind in support["grounding_chunk_indices"]:
                 try:
-                    chunk = candidate.grounding_metadata.grounding_chunks[ind]
-                    resolved_url = resolved_urls_map.get(chunk.web.uri, None)
+                    chunk = grounding_metadata["grounding_chunks"][ind]
+                    resolved_url = resolved_urls_map.get(chunk["web"]["uri"], None)
                     citation["segments"].append(
                         {
-                            "label": chunk.web.title.split(".")[:-1][0],
+                            "label": chunk["web"]["title"].split(".")[:-1][0],
                             "short_url": resolved_url,
-                            "value": chunk.web.uri,
+                            "value": chunk["web"]["uri"],
                         }
                     )
                 except (IndexError, AttributeError, NameError):