Skip to content

Commit 39d74ad

Browse files
authored
feat(llm): Add automatic provider inference for LangChain LLMs (#1460)
1 parent a2f17bc commit 39d74ad

File tree

2 files changed

+190
-1
lines changed

2 files changed

+190
-1
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,70 @@ def __init__(self, inner_exception: Any):
4646
self.inner_exception = inner_exception
4747

4848

49+
def _infer_provider_from_module(llm: BaseLanguageModel) -> Optional[str]:
50+
"""Infer provider name from the LLM's module path.
51+
52+
This function extracts the provider name from LangChain package naming conventions:
53+
- langchain_openai -> openai
54+
- langchain_anthropic -> anthropic
55+
- langchain_google_genai -> google_genai
56+
- langchain_nvidia_ai_endpoints -> nvidia_ai_endpoints
57+
- langchain_community.chat_models.ollama -> ollama
58+
59+
For patched/wrapped classes, checks base classes as well.
60+
61+
Args:
62+
llm: The LLM instance
63+
64+
Returns:
65+
The inferred provider name, or None if it cannot be determined
66+
"""
67+
module = type(llm).__module__
68+
69+
if module.startswith("langchain_"):
70+
package = module.split(".")[0]
71+
provider = package.replace("langchain_", "")
72+
73+
if provider == "community":
74+
parts = module.split(".")
75+
if len(parts) >= 3:
76+
provider = parts[-1]
77+
return provider
78+
else:
79+
return provider
80+
81+
for base_class in type(llm).__mro__[1:]:
82+
base_module = base_class.__module__
83+
if base_module.startswith("langchain_"):
84+
package = base_module.split(".")[0]
85+
provider = package.replace("langchain_", "")
86+
87+
if provider == "community":
88+
parts = base_module.split(".")
89+
if len(parts) >= 3:
90+
provider = parts[-1]
91+
return provider
92+
else:
93+
return provider
94+
95+
return None
96+
97+
98+
def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
99+
"""Get the provider name for an LLM instance by inferring from module path.
100+
101+
This function extracts the provider name from LangChain package naming conventions.
102+
See _infer_provider_from_module for details on the inference logic.
103+
104+
Args:
105+
llm: The LLM instance
106+
107+
Returns:
108+
The provider name if it can be inferred, None otherwise
109+
"""
110+
return _infer_provider_from_module(llm)
111+
112+
49113
def _infer_model_name(llm: BaseLanguageModel):
50114
"""Helper to infer the model name based from an LLM instance.
51115
@@ -126,7 +190,7 @@ def _setup_llm_call_info(
126190
llm_call_info_var.set(llm_call_info)
127191

128192
llm_call_info.llm_model_name = model_name or _infer_model_name(llm)
129-
llm_call_info.llm_provider_name = model_provider
193+
llm_call_info.llm_provider_name = model_provider or _infer_provider_from_module(llm)
130194

131195

132196
def _prepare_callbacks(

tests/test_actions_llm_utils.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from nemoguardrails.actions.llm.utils import _infer_provider_from_module
17+
18+
19+
class MockOpenAILLM:
20+
__module__ = "langchain_openai.chat_models"
21+
22+
23+
class MockAnthropicLLM:
24+
__module__ = "langchain_anthropic.chat_models"
25+
26+
27+
class MockNVIDIALLM:
28+
__module__ = "langchain_nvidia_ai_endpoints.chat_models"
29+
30+
31+
class MockCommunityOllama:
32+
__module__ = "langchain_community.chat_models.ollama"
33+
34+
35+
class MockUnknownLLM:
36+
__module__ = "some_custom_package.models"
37+
38+
39+
class MockNVIDIAOriginal:
40+
__module__ = "langchain_nvidia_ai_endpoints.chat_models"
41+
42+
43+
class MockPatchedNVIDIA(MockNVIDIAOriginal):
44+
__module__ = "nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch"
45+
46+
47+
def test_infer_provider_openai():
48+
llm = MockOpenAILLM()
49+
provider = _infer_provider_from_module(llm)
50+
assert provider == "openai"
51+
52+
53+
def test_infer_provider_anthropic():
54+
llm = MockAnthropicLLM()
55+
provider = _infer_provider_from_module(llm)
56+
assert provider == "anthropic"
57+
58+
59+
def test_infer_provider_nvidia_ai_endpoints():
60+
llm = MockNVIDIALLM()
61+
provider = _infer_provider_from_module(llm)
62+
assert provider == "nvidia_ai_endpoints"
63+
64+
65+
def test_infer_provider_community_ollama():
66+
llm = MockCommunityOllama()
67+
provider = _infer_provider_from_module(llm)
68+
assert provider == "ollama"
69+
70+
71+
def test_infer_provider_unknown():
72+
llm = MockUnknownLLM()
73+
provider = _infer_provider_from_module(llm)
74+
assert provider is None
75+
76+
77+
def test_infer_provider_from_patched_class():
78+
llm = MockPatchedNVIDIA()
79+
provider = _infer_provider_from_module(llm)
80+
assert provider == "nvidia_ai_endpoints"
81+
82+
83+
def test_infer_provider_checks_base_classes():
84+
class BaseOpenAI:
85+
__module__ = "langchain_openai.chat_models"
86+
87+
class CustomWrapper(BaseOpenAI):
88+
__module__ = "my_custom_wrapper.llms"
89+
90+
llm = CustomWrapper()
91+
provider = _infer_provider_from_module(llm)
92+
assert provider == "openai"
93+
94+
95+
def test_infer_provider_multiple_inheritance():
96+
class BaseNVIDIA:
97+
__module__ = "langchain_nvidia_ai_endpoints.chat_models"
98+
99+
class Mixin:
100+
__module__ = "some_mixin.utils"
101+
102+
class MultipleInheritance(Mixin, BaseNVIDIA):
103+
__module__ = "custom_package.models"
104+
105+
llm = MultipleInheritance()
106+
provider = _infer_provider_from_module(llm)
107+
assert provider == "nvidia_ai_endpoints"
108+
109+
110+
def test_infer_provider_deeply_nested_inheritance():
111+
class Original:
112+
__module__ = "langchain_anthropic.chat_models"
113+
114+
class Wrapper1(Original):
115+
__module__ = "wrapper1.models"
116+
117+
class Wrapper2(Wrapper1):
118+
__module__ = "wrapper2.models"
119+
120+
class Wrapper3(Wrapper2):
121+
__module__ = "wrapper3.models"
122+
123+
llm = Wrapper3()
124+
provider = _infer_provider_from_module(llm)
125+
assert provider == "anthropic"

0 commit comments

Comments
 (0)