diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py
index c2a87fe06b418..f3cfd575ae8b8 100644
--- a/libs/partners/openai/langchain_openai/embeddings/base.py
+++ b/libs/partners/openai/langchain_openai/embeddings/base.py
@@ -19,6 +19,9 @@
 
 logger = logging.getLogger(__name__)
 
+# OpenAI API limits
+MAX_TOKENS_PER_REQUEST = 300000  # OpenAI's max tokens per embedding request
+
 
 def _process_batched_chunked_embeddings(
     num_texts: int,
@@ -540,14 +543,38 @@ def _get_len_safe_embeddings(
         client_kwargs = {**self._invocation_params, **kwargs}
         _iter, tokens, indices = self._tokenize(texts, _chunk_size)
         batched_embeddings: list[list[float]] = []
-        for i in _iter:
+        # Calculate actual token counts for each chunk
+        token_counts = [len(t) if isinstance(t, list) else len(t.split()) for t in tokens]
+
+        # Process in batches respecting the token limit
+        i = 0
+        while i < len(tokens):
+            # Determine how many chunks we can include in this batch
+            batch_token_count = 0
+            batch_end = i
+
+            for j in range(i, min(i + _chunk_size, len(tokens))):
+                chunk_tokens = token_counts[j]
+                # Check if adding this chunk would exceed the limit
+                if batch_token_count + chunk_tokens > MAX_TOKENS_PER_REQUEST:
+                    if batch_end == i:
+                        # Single chunk exceeds limit - handle it anyway
+                        batch_end = j + 1
+                    break
+                batch_token_count += chunk_tokens
+                batch_end = j + 1
+
+            # Make API call with this batch
+            batch_tokens = tokens[i:batch_end]
             response = self.client.create(
-                input=tokens[i : i + _chunk_size], **client_kwargs
+                input=batch_tokens, **client_kwargs
             )
             if not isinstance(response, dict):
                 response = response.model_dump()
             batched_embeddings.extend(r["embedding"] for r in response["data"])
+            i = batch_end
+
 
         embeddings = _process_batched_chunked_embeddings(
             len(texts), tokens, batched_embeddings, indices, self.skip_empty
         )
@@ -594,15 +621,38 @@ async def _aget_len_safe_embeddings(
             None, self._tokenize, texts, _chunk_size
         )
         batched_embeddings: list[list[float]] = []
-        for i in range(0, len(tokens), _chunk_size):
+        # Calculate actual token counts for each chunk
+        token_counts = [len(t) if isinstance(t, list) else len(t.split()) for t in tokens]
+
+        # Process in batches respecting the token limit
+        i = 0
+        while i < len(tokens):
+            # Determine how many chunks we can include in this batch
+            batch_token_count = 0
+            batch_end = i
+
+            for j in range(i, min(i + _chunk_size, len(tokens))):
+                chunk_tokens = token_counts[j]
+                # Check if adding this chunk would exceed the limit
+                if batch_token_count + chunk_tokens > MAX_TOKENS_PER_REQUEST:
+                    if batch_end == i:
+                        # Single chunk exceeds limit - handle it anyway
+                        batch_end = j + 1
+                    break
+                batch_token_count += chunk_tokens
+                batch_end = j + 1
+
+            # Make API call with this batch
+            batch_tokens = tokens[i:batch_end]
             response = await self.async_client.create(
-                input=tokens[i : i + _chunk_size], **client_kwargs
+                input=batch_tokens, **client_kwargs
             )
-
             if not isinstance(response, dict):
                 response = response.model_dump()
             batched_embeddings.extend(r["embedding"] for r in response["data"])
+            i = batch_end
+
 
         embeddings = _process_batched_chunked_embeddings(
             len(texts), tokens, batched_embeddings, indices, self.skip_empty
         )
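Both hunks above apply the same greedy strategy: walk the tokenized chunks in order, keep a running token total, and start a new request whenever adding the next chunk would push the total past the 300k budget (a single oversized chunk is still sent on its own). Below is a minimal standalone sketch of that strategy outside the LangChain classes; the function name and the example input are illustrative, not part of the patch.

def batch_by_token_budget(
    token_chunks: list,
    chunk_size: int = 1000,
    max_tokens_per_request: int = 300_000,
) -> list:
    """Greedily pack tokenized chunks into batches that stay under the budget."""
    # Same count heuristic as the patch: token lists count their length,
    # plain strings fall back to a whitespace split.
    counts = [len(c) if isinstance(c, list) else len(c.split()) for c in token_chunks]
    batches = []
    i = 0
    while i < len(token_chunks):
        total = 0
        end = i
        for j in range(i, min(i + chunk_size, len(token_chunks))):
            if total + counts[j] > max_tokens_per_request:
                if end == i:
                    # A single oversized chunk is still sent as its own request.
                    end = j + 1
                break
            total += counts[j]
            end = j + 1
        batches.append(token_chunks[i:end])
        i = end
    return batches


if __name__ == "__main__":
    # Ten chunks of ~100k tokens each: no batch should exceed the 300k budget.
    chunks = [[1] * 100_000 for _ in range(10)]
    sizes = [sum(len(c) for c in b) for b in batch_by_token_budget(chunks)]
    assert all(s <= 300_000 for s in sizes)
    print(sizes)  # [300000, 300000, 300000, 100000]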
diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py
index f87dc181a37b8..0a133ebdcc6ff 100644
--- a/libs/partners/openai/tests/unit_tests/embeddings/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/embeddings/test_base.py
@@ -96,3 +96,55 @@ async def test_embed_with_kwargs_async() -> None:
     mock_create.assert_any_call(input=texts, **client_kwargs)
     assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
 
+
+def test_embeddings_respects_token_limit() -> None:
+    """Test that embeddings respect the 300k token per request limit.
+
+    This tests the fix for issue #31227 where large batches of embeddings
+    would exceed OpenAI's token limit.
+    """
+    from langchain_openai import OpenAIEmbeddings
+
+    # Create embeddings instance
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-ada-002",
+        openai_api_key="test-key"
+    )
+
+    # Mock the client to track API calls
+    from unittest.mock import MagicMock, Mock
+
+    call_counts = []
+    original_create = embeddings.client.create
+
+    def mock_create(input, **kwargs):
+        # Track how many tokens in this call
+        if isinstance(input, list):
+            total_tokens = sum(len(t) if isinstance(t, list) else len(t.split()) for t in input)
+            call_counts.append(total_tokens)
+            # Verify this call doesn't exceed limit
+            assert total_tokens <= 300000, f"Batch exceeded token limit: {total_tokens} tokens"
+
+        # Return mock response
+        mock_response = Mock()
+        mock_response.model_dump.return_value = {
+            "data": [{"embedding": [0.1] * 1536} for _ in range(len(input) if isinstance(input, list) else 1)]
+        }
+        return mock_response
+
+    embeddings.client.create = mock_create
+
+    # Create a scenario that would exceed 300k tokens in a single batch
+    # with default chunk_size=1000
+    # Simulate 500 texts with ~1000 tokens each = 500k tokens total
+    large_texts = ["word " * 1000 for _ in range(500)]
+
+    # This should not raise an error anymore
+    result = embeddings.embed_documents(large_texts)
+
+    # Verify we made multiple API calls to respect the limit
+    assert len(call_counts) > 1, "Should have split into multiple batches"
+
+    # Verify each call respected the limit
+    for count in call_counts:
+        assert count <= 300000, f"Batch exceeded limit: {count}"
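Both the patch and the test estimate chunk sizes from the input itself: len(t) when the tokenizer returns token-id lists, and a whitespace split as a rough fallback for plain strings. OpenAI's 300k-per-request accounting is based on tiktoken tokens, so if exact counts were wanted up front they could be computed as in the sketch below; the helper name and model choice are illustrative and not part of the patch or the test suite.

import tiktoken


def count_embedding_tokens(texts, model="text-embedding-ada-002"):
    """Return the exact tiktoken token count for each text."""
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model names fall back to a general-purpose encoding.
        enc = tiktoken.get_encoding("cl100k_base")
    return [len(enc.encode(text)) for text in texts]


if __name__ == "__main__":
    counts = count_embedding_tokens(["hello world", "word " * 1000])
    print(counts)       # exact values depend on the encoding, e.g. [2, 1000]
    print(sum(counts))  # total that counts against the per-request budget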