Skip to content

Commit aafd733

Browse files
authored
feat(llm): Add custom HTTP headers support to ChatNVIDIA provider (#1461)
* feat(llm): Add custom HTTP headers support to ChatNVIDIA provider

Add custom HTTP headers support to the ChatNVIDIA class patch, enabling users to pass custom headers (authentication tokens, request IDs, billing information, etc.) with all requests to NVIDIA AI endpoints.

Implementation approach:
- Added an optional custom_headers field to the ChatNVIDIA class, with Pydantic v2 compatibility.
- Implemented runtime method wrapping that intercepts _client.get_req() and _client.get_req_stream() to merge custom headers with existing headers.
- Included automatic version detection to ensure compatibility with langchain-nvidia-ai-endpoints >= 0.3.0, with clear error messages for older versions.
- Works with both synchronous invoke() and streaming requests, and is fully compatible with VLMs (Vision Language Models).
1 parent f7b5daf commit aafd733

File tree

2 files changed

+506
-5
lines changed

2 files changed

+506
-5
lines changed

nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,22 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import inspect
1617
import logging
1718
from functools import wraps
18-
from typing import Any, List, Optional
19+
from typing import Any, Dict, List, Optional
1920

2021
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
2122
from langchain_core.language_models.chat_models import generate_from_stream
2223
from langchain_core.messages import BaseMessage
2324
from langchain_core.outputs import ChatResult
2425
from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
25-
from pydantic.v1 import Field
26+
from pydantic import Field
2627

27-
log = logging.getLogger(__name__)
28+
log = logging.getLogger(__name__) # pragma: no cover
2829

2930

30-
def stream_decorator(func):
31+
def stream_decorator(func): # pragma: no cover
3132
@wraps(func)
3233
def wrapper(
3334
self,
@@ -51,10 +52,52 @@ def wrapper(
5152

5253
# NOTE: this needs to have the same name as the original class,
5354
# otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
54-
class ChatNVIDIA(ChatNVIDIAOriginal):
55+
class ChatNVIDIA(ChatNVIDIAOriginal): # pragma: no cover
5556
streaming: bool = Field(
5657
default=False, description="Whether to use streaming or not"
5758
)
59+
custom_headers: Optional[Dict[str, str]] = Field(
60+
default=None, description="Custom HTTP headers to send with requests"
61+
)
62+
63+
def __init__(self, **kwargs: Any):
    """Initialize the model and, when requested, enable custom HTTP headers.

    All keyword arguments are forwarded unchanged to the parent
    constructor. If ``custom_headers`` is set to a non-empty mapping, the
    underlying ``_client`` is checked for the request methods that the
    header injection relies on, and those methods are then wrapped via
    ``_wrap_client_methods``.

    Raises:
        RuntimeError: If ``custom_headers`` is set but ``_client`` lacks
            ``get_req`` or ``get_req_stream``, or either of them does not
            accept an ``extra_headers`` parameter.
    """
    super().__init__(**kwargs)

    # Nothing to validate or wrap when no (or empty) headers were supplied.
    if not self.custom_headers:
        return

    # The requirement is quoted in the pip command so that a shell does not
    # interpret ">=" as an output redirection.
    custom_headers_error = (
        "custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. "
        "Your version does not support the required client structure or "
        "extra_headers parameter. Please upgrade: "
        'pip install --upgrade "langchain-nvidia-ai-endpoints>=0.3.0"'
    )

    # Both methods are wrapped later, so both must exist and accept
    # `extra_headers`; checking only one would let an incompatible client
    # fail with an unrelated AttributeError during wrapping.
    for method_name in ("get_req", "get_req_stream"):
        method = getattr(self._client, method_name, None)
        if method is None:
            raise RuntimeError(custom_headers_error)
        if "extra_headers" not in inspect.signature(method).parameters:
            raise RuntimeError(custom_headers_error)

    self._wrap_client_methods()
80+
81+
def _wrap_client_methods(self):
    """Replace the client's request methods with header-merging wrappers.

    ``self._client.get_req`` and ``self._client.get_req_stream`` are each
    replaced by a function that merges ``self.custom_headers`` into the
    ``extra_headers`` argument before delegating to the original method.
    On a key collision the value from ``custom_headers`` wins.

    ``custom_headers`` is read at call time, so changes made to it after
    wrapping are picked up; a value of ``None`` is treated as "no custom
    headers" instead of raising.
    """

    def with_custom_headers(original):
        # Build one wrapper per client method so both share the same logic.
        def wrapped(payload: Optional[dict] = None, extra_headers: Optional[dict] = None):
            # Custom headers come last so they override per-call headers.
            merged_headers = {
                **(extra_headers or {}),
                **(self.custom_headers or {}),
            }
            return original(payload=payload or {}, extra_headers=merged_headers)

        return wrapped

    # object.__setattr__ is used (as in the original code) to assign
    # directly, bypassing any custom __setattr__ on the client's class.
    for method_name in ("get_req", "get_req_stream"):
        original_method = getattr(self._client, method_name)
        object.__setattr__(
            self._client, method_name, with_custom_headers(original_method)
        )
58101

59102
@stream_decorator
60103
def _generate(

0 commit comments

Comments
 (0)