From 7e6e740d69e702c8c5fc5bf64a42595e46e0107e Mon Sep 17 00:00:00 2001
From: Josh XT
Date: Tue, 31 Dec 2024 22:11:19 -0500
Subject: [PATCH] improve chunking

---
 agixt/providers/rotation.py | 158 ++++++++++++++++++++----------------
 1 file changed, 88 insertions(+), 70 deletions(-)

diff --git a/agixt/providers/rotation.py b/agixt/providers/rotation.py
index e8981994f28a..5eaeefe9b41f 100644
--- a/agixt/providers/rotation.py
+++ b/agixt/providers/rotation.py
@@ -14,18 +14,38 @@ def score_chunk(chunk: str, keywords: set) -> int:
     return score
 
 
-def chunk_content(text: str, chunk_size: int) -> List[str]:
+def chunk_content(text: str, chunk_size: int, max_tokens: int = 60000) -> List[str]:
+    """
+    Split content into chunks while respecting both character and token limits.
+
+    Args:
+        text: Text to chunk
+        chunk_size: Target size for each chunk in characters
+        max_tokens: Maximum tokens allowed for processing (default: 60000 for Deepseek)
+    """
     doc = nlp(text)
     sentences = list(doc.sents)
     content_chunks = []
     chunk = []
     chunk_len = 0
+    chunk_text = ""
+    total_len = 0
 
     keywords = set(extract_keywords(doc=doc, limit=10))
+
     for sentence in sentences:
         sentence_tokens = len(sentence)
+        # Estimate tokens (rough approximation: 4 characters per token)
+        estimated_total_tokens = (total_len + len(str(sentence))) // 4
+
+        if estimated_total_tokens > max_tokens:
+            break
+
         if chunk_len + sentence_tokens > chunk_size and chunk:
             chunk_text = " ".join(token.text for token in chunk)
-            content_chunks.append((score_chunk(chunk_text, keywords), chunk_text))
+            score = score_chunk(chunk_text, keywords)
+            content_chunks.append((score, chunk_text))
+            total_len += len(chunk_text)
             chunk = []
             chunk_len = 0
@@ -34,11 +54,23 @@ def chunk_content(text: str, chunk_size: int) -> List[str]:
 
     if chunk:
         chunk_text = " ".join(token.text for token in chunk)
-        content_chunks.append((score_chunk(chunk_text, keywords), chunk_text))
+        score = score_chunk(chunk_text, keywords)
+        content_chunks.append((score, chunk_text))
 
-    # Sort the chunks by their score in descending order before returning them
+    # Sort by score and take only enough chunks to stay under token limit
     content_chunks.sort(key=lambda x: x[0], reverse=True)
-    return [chunk_text for score, chunk_text in content_chunks]
+    result_chunks = []
+    total_len = 0
+
+    for score, chunk_text in content_chunks:
+        # Estimate tokens for this chunk
+        chunk_tokens = len(chunk_text) // 4
+        if total_len + chunk_tokens > max_tokens:
+            break
+        result_chunks.append(chunk_text)
+        total_len += chunk_tokens
+
+    return result_chunks
 
 
 class RotationProvider:
@@ -69,89 +101,75 @@ async def _analyze_chunk(
         self, chunk: str, chunk_index: int, prompt: str
     ) -> List[int]:
         """Analyze a single large chunk to identify relevant smaller chunks."""
-        small_chunks = chunk_content(chunk, self.SMALL_CHUNK_SIZE)
+        # Use smaller max_tokens to leave room for prompt and completion
+        small_chunks = chunk_content(chunk, self.SMALL_CHUNK_SIZE, max_tokens=40000)
 
         if not small_chunks:
             return []
 
-        # Process small chunks in batches to stay within token limits
-        MAX_CHUNKS_PER_PROMPT = 5  # Adjust this based on actual token usage
-        results = []
-
-        for batch_start in range(0, len(small_chunks), MAX_CHUNKS_PER_PROMPT):
-            batch_end = min(batch_start + MAX_CHUNKS_PER_PROMPT, len(small_chunks))
-            batch_chunks = small_chunks[batch_start:batch_end]
-
-            analysis_prompt = (
-                f"Below is part {batch_start//MAX_CHUNKS_PER_PROMPT + 1} of chunk {chunk_index + 1}, "
-                f"containing sub-chunks {batch_start + 1} to {batch_end} of the total {len(small_chunks)} sub-chunks.\n"
-                "Analyze which sub-chunks are relevant to answering the query.\n"
-                "Respond ONLY with comma-separated sub-chunk numbers (using the original full numbering).\n"
-                "Example response format: 1,4,7\n"
-                "If no sub-chunks are relevant, respond with: none\n\n"
-                f"Query: {prompt}\n\n"
-                "Sub-chunks:\n"
-            )
+        analysis_prompt = (
+            f"Below is chunk {chunk_index + 1} of a larger codebase, split into {len(small_chunks)} "
+            f"sub-chunks, followed by a user query.\n"
+            "Analyze which sub-chunks are relevant to answering the query.\n"
+            "Respond ONLY with comma-separated sub-chunk numbers (1-based indexing).\n"
+            "Example response format: 1,4,7\n\n"
+            f"Query: {prompt}\n\n"
+            "Sub-chunks:\n"
+        )
 
-            for i, small_chunk in enumerate(batch_chunks, batch_start + 1):
-                analysis_prompt += f"\nSUB-CHUNK {i}:\n{small_chunk}\n"
+        for i, small_chunk in enumerate(small_chunks, 1):
+            analysis_prompt += f"\nSUB-CHUNK {i}:\n{small_chunk}\n"
 
+        try:
+            agent = Agent(
+                agent_name=self.agent_name,
+                user=self.user,
+                ApiClient=self.ApiClient,
+            )
+            if "agent_name" in self.AGENT_SETTINGS:
+                del self.AGENT_SETTINGS["agent_name"]
+            if "user" in self.AGENT_SETTINGS:
+                del self.AGENT_SETTINGS["user"]
+            if "ApiClient" in self.AGENT_SETTINGS:
+                del self.AGENT_SETTINGS["ApiClient"]
+            agent.PROVIDER = Providers(
+                name=self.ANALYSIS_PROVIDER,
+                ApiClient=self.ApiClient,
+                agent_name=self.agent_name,
+                user=self.user,
+                **self.AGENT_SETTINGS,
+            )
             try:
-                agent = Agent(
-                    agent_name=self.agent_name,
-                    user=self.user,
-                    ApiClient=self.ApiClient,
+                result = await agent.inference(prompt=analysis_prompt)
+            except Exception as e:
+                logging.error(
+                    f"Chunk analysis failed for chunk {chunk_index + 1}: {str(e)}"
                 )
-                if "agent_name" in self.AGENT_SETTINGS:
-                    del self.AGENT_SETTINGS["agent_name"]
-                if "user" in self.AGENT_SETTINGS:
-                    del self.AGENT_SETTINGS["user"]
-                if "ApiClient" in self.AGENT_SETTINGS:
-                    del self.AGENT_SETTINGS["ApiClient"]
                 agent.PROVIDER = Providers(
-                    name=self.ANALYSIS_PROVIDER,
+                    name="rotation",
                     ApiClient=self.ApiClient,
                     agent_name=self.agent_name,
                     user=self.user,
                     **self.AGENT_SETTINGS,
                 )
-                try:
-                    result = await agent.inference(prompt=analysis_prompt)
-                except Exception as e:
-                    logging.error(
-                        f"Chunk analysis failed for batch {batch_start//MAX_CHUNKS_PER_PROMPT + 1} of chunk {chunk_index + 1}: {str(e)}"
-                    )
-                    agent.PROVIDER = Providers(
-                        name="rotation",
-                        ApiClient=self.ApiClient,
-                        agent_name=self.agent_name,
-                        user=self.user,
-                        **self.AGENT_SETTINGS,
-                    )
-                    result = await agent.inference(prompt=analysis_prompt)
+                result = await agent.inference(prompt=analysis_prompt)
 
-                if result.strip().lower() != "none":
-                    # Parse comma-separated numbers, convert to 0-based indexing
-                    chunk_numbers = [int(n.strip()) - 1 for n in result.split(",")]
-                    # Validate chunk numbers
-                    valid_numbers = [
-                        n for n in chunk_numbers if 0 <= n < len(small_chunks)
-                    ]
-                    results.extend(valid_numbers)
+            # Parse comma-separated numbers, convert to 0-based indexing
+            chunk_numbers = [int(n.strip()) - 1 for n in result.split(",")]
+            # Validate chunk numbers
+            valid_numbers = [n for n in chunk_numbers if 0 <= n < len(small_chunks)]
 
-            except Exception as e:
-                logging.error(
-                    f"Batch analysis failed for chunk {chunk_index + 1}, batch {batch_start//MAX_CHUNKS_PER_PROMPT + 1}: {str(e)}"
+            if not valid_numbers:
+                logging.warning(
+                    f"No valid chunk numbers returned for chunk {chunk_index + 1}, using all sub-chunks"
                 )
-                # On complete failure, include all chunks from this batch
-                results.extend(range(batch_start, batch_end))
+                return list(range(len(small_chunks)))
 
-        if not results:
-            logging.warning(
-                f"No valid chunk numbers returned for any batch in chunk {chunk_index + 1}, using all sub-chunks"
+            return valid_numbers
+        except Exception as e:
+            logging.error(
+                f"Chunk analysis failed for chunk {chunk_index + 1}: {str(e)}"
            )
-            return list(range(len(small_chunks)))
-
-        return sorted(set(results))  # Remove duplicates and sort
+            return list(range(len(small_chunks)))  # Return all sub-chunks on failure
 
     async def _get_relevant_chunks(self, text: str, prompt: str) -> str:
         """Split text into large chunks and analyze them in parallel."""
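
Note on the token budgeting: chunk_content now estimates tokens with a rough
4-characters-per-token heuristic and keeps the highest-scoring chunks until the
budget is spent. A minimal standalone sketch of that selection step follows;
the scored_chunks input is a hypothetical stand-in for the function's internal
(score, text) tuples, with no spaCy dependency, so this is an illustration of
the idea rather than the patch's exact code:

    def select_chunks_under_budget(scored_chunks, max_tokens=60000):
        # Highest-scoring chunks first, mirroring content_chunks.sort(...)
        scored_chunks.sort(key=lambda x: x[0], reverse=True)
        selected, total_tokens = [], 0
        for score, text in scored_chunks:
            chunk_tokens = len(text) // 4  # same ~4 chars/token estimate as the patch
            if total_tokens + chunk_tokens > max_tokens:
                break
            selected.append(text)
            total_tokens += chunk_tokens
        return selected

    # With a budget of 4 tokens, the 12-char chunk (3 tokens) and the 4-char
    # chunk (1 token) fit; the next chunk would overflow and stops the loop.
    print(select_chunks_under_budget([(3, "a" * 8), (9, "b" * 12), (5, "c" * 4)], max_tokens=4))
    # -> ['bbbbbbbbbbbb', 'cccc']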
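
Note on response parsing: the single analysis prompt asks the model for 1-based
numbers like "1,4,7", which _analyze_chunk converts to validated 0-based
indices, falling back to all sub-chunks when nothing usable comes back. A small
sketch of that contract; parse_subchunk_numbers is a hypothetical helper, and
it folds the patch's outer try/except into a local one:

    def parse_subchunk_numbers(result: str, num_chunks: int) -> list:
        try:
            # "1,4,7" -> [0, 3, 6], matching the 1-based-to-0-based conversion
            numbers = [int(n.strip()) - 1 for n in result.split(",")]
        except ValueError:
            return list(range(num_chunks))  # unparseable reply: keep all sub-chunks
        valid = [n for n in numbers if 0 <= n < num_chunks]
        return valid if valid else list(range(num_chunks))

    assert parse_subchunk_numbers("1,4,7", 8) == [0, 3, 6]
    assert parse_subchunk_numbers("not a list", 3) == [0, 1, 2]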