Merge pull request #18 from gkumbhat/detector_dispatcher_deco_2
Detector dispatcher decorator
gkumbhat authored Feb 3, 2025
2 parents a845837 + ae50dd8 commit a435f31
Showing 6 changed files with 140 additions and 28 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -13,7 +13,8 @@ classifiers = [
 ]
 
 dependencies = [
-    "vllm>=0.7.0"
+    "vllm @ git+https://github.com/vllm-project/[email protected] ; sys_platform == 'darwin'",
+    "vllm>=0.7.1 ; sys_platform != 'darwin'",
 ]
 
 [project.optional-dependencies]
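Note: the "; sys_platform == ..." suffixes are PEP 508 environment markers, so pip resolves the Git-based vLLM build only on macOS and the regular PyPI release everywhere else. A quick way to check which branch applies on a given machine (a sketch using the standard packaging library, which is not part of this repo):

from packaging.markers import Marker

# Evaluates the marker against the running interpreter's environment
print(Marker("sys_platform == 'darwin'").evaluate())  # True on macOS, False elsewhere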
25 changes: 13 additions & 12 deletions tests/generative_detectors/test_granite_guardian.py
@@ -30,6 +30,7 @@
     DetectionChatMessageParam,
     DetectionResponse,
 )
+from vllm_detector_adapter.utils import DetectorType
 
 MODEL_NAME = "ibm-granite/granite-guardian"  # Example granite-guardian model
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -177,8 +178,8 @@ def test_preprocess_chat_request_with_detector_params(granite_guardian_detection
         ],
         detector_params=detector_params,
     )
-    processed_request = granite_guardian_detection_instance.preprocess_chat_request(
-        initial_request
+    processed_request = granite_guardian_detection_instance.preprocess_request(
+        initial_request, fn_type=DetectorType.TEXT_CHAT
     )
     assert type(processed_request) == ChatDetectionRequest
     # Processed request should not have these extra params
@@ -214,8 +215,8 @@ def test_request_to_chat_completion_request_prompt_analysis(granite_guardian_det
         },
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
     )
     assert type(chat_request) == ChatCompletionRequest
@@ -247,8 +248,8 @@ def test_request_to_chat_completion_request_reponse_analysis(
         },
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
     )
     assert type(chat_request) == ChatCompletionRequest
@@ -274,8 +275,8 @@ def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detect
         detector_params={"n": 2, "chat_template_kwargs": {}},  # no guardian config
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
     )
     assert type(chat_request) == ErrorResponse
@@ -294,8 +295,8 @@ def test_request_to_chat_completion_request_empty_guardian_config(
         detector_params={"n": 2, "chat_template_kwargs": {"guardian_config": {}}},
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
     )
     assert type(chat_request) == ErrorResponse
@@ -317,8 +318,8 @@ def test_request_to_chat_completion_request_unsupported_risk_name(
         },
     )
     chat_request = (
-        granite_guardian_detection_instance.request_to_chat_completion_request(
-            context_request, MODEL_NAME
+        granite_guardian_detection_instance._request_to_chat_completion_request(
+            context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
    )
     assert type(chat_request) == ErrorResponse
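Taken together, the test updates reflect the new dispatched entry points; the call shape changes roughly like this (a sketch using the names from the tests above):

# before
granite_guardian_detection_instance.preprocess_chat_request(initial_request)

# after
granite_guardian_detection_instance.preprocess_request(
    initial_request, fn_type=DetectorType.TEXT_CHAT
)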
81 changes: 81 additions & 0 deletions vllm_detector_adapter/detector_dispatcher.py
@@ -0,0 +1,81 @@
# Standard
import functools

# Global registry of all registered functions, keyed by qualified name,
# with their detector types
global_fn_list = dict()


def detector_dispatcher(types=None):
    """Decorator to dispatch to a processing function based on detector type.

    This decorator allows the same function name to be reused for different
    types of detectors. For example, text chat and context analysis detectors
    can each define a function with the same name but different arguments and
    implementations.

    NOTE: When invoking a decorated function, the caller must specify the
    detector type via the fn_type argument.

    CAUTION: Since this decorator allows a name to be reused, distinct types
    must be used when registering different functions under the same name.

    Args:
        types (list): Detector types this function applies to.
        args: Positional arguments passed to the processing function.
        kwargs: Keyword arguments passed to the processing function.

    Examples
    --------
    @detector_dispatcher(types=["foo"])
    def f(x):
        pass

    # The decorator can take multiple types as well
    @detector_dispatcher(types=["bar", "baz"])
    def f(x):
        pass

    When calling these functions, specify the detector type as follows:

    f(x, fn_type="foo")
    f(x, fn_type="bar")
    """
    global global_fn_list

    if not types:
        raise ValueError("Must specify types.")

    def decorator(func):
        fn_name = func.__qualname__

        if fn_name not in global_fn_list:
            # Associate each function with its types to create a dictionary of form:
            # {"fn_name": {type1: function, type2: function}}
            # NOTE: the "function" values here are function references
            global_fn_list[fn_name] = {t: func for t in types}
        elif fn_name in global_fn_list and (types & global_fn_list[fn_name].keys()):
            # Error out if a function with the same name is already registered
            # under any of these types
            raise ValueError("Function already registered with the same types.")
        else:
            # Add the function to the global registry under its remaining types
            global_fn_list[fn_name] |= {t: func for t in types}

        @functools.wraps(func)
        def wrapper(*args, fn_type=None, **kwargs):
            fn_name = func.__qualname__

            if not fn_type:
                raise ValueError("Must specify fn_type.")

            if fn_type not in global_fn_list[fn_name].keys():
                raise ValueError("Invalid fn_type.")

            # Grab the function registered under this qualified name and the
            # specified type, then call it
            return global_fn_list[fn_name][fn_type](*args, **kwargs)

        return wrapper

    return decorator
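To make the dispatch behavior concrete, here is a minimal sketch of how the decorator is meant to be used. ExampleDetector and its method bodies are hypothetical; the imports are the real modules from this PR, so running it requires the package to be installed:

from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
from vllm_detector_adapter.utils import DetectorType


class ExampleDetector:
    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
    def preprocess_request(self, request):
        # Chat-specific preprocessing would go here
        return ("chat", request)

    @detector_dispatcher(types=[DetectorType.TEXT_CONTEXT_DOC])
    def preprocess_request(self, request):  # noqa: F811 -- same name by design
        # Context-doc-specific preprocessing would go here
        return ("context_doc", request)


detector = ExampleDetector()
assert detector.preprocess_request("hi", fn_type=DetectorType.TEXT_CHAT) == ("chat", "hi")
assert detector.preprocess_request("hi", fn_type=DetectorType.TEXT_CONTEXT_DOC) == ("context_doc", "hi")

Both definitions share the qualified name ExampleDetector.preprocess_request, so they land in the same registry entry under different types, and the fn_type argument at the call site selects which implementation runs.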
14 changes: 10 additions & 4 deletions vllm_detector_adapter/generative_detectors/base.py
@@ -13,12 +13,14 @@
 import torch
 
 # Local
+from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
 from vllm_detector_adapter.logging import init_logger
 from vllm_detector_adapter.protocol import (
     ChatDetectionRequest,
     ContextAnalysisRequest,
     DetectionResponse,
 )
+from vllm_detector_adapter.utils import DetectorType
 
 logger = init_logger(__name__)
 
@@ -80,13 +82,17 @@ def apply_output_template(
 
     ##### Chat request processing functions ####################################
 
-    def apply_task_template_to_chat(
+    # Usage of detector_dispatcher allows the same function name to be used for
+    # different detector types with different arguments and implementations.
+    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
+    def apply_task_template(
         self, request: ChatDetectionRequest
     ) -> Union[ChatDetectionRequest, ErrorResponse]:
         """Apply task template on the chat request"""
         return request
 
-    def preprocess_chat_request(
+    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
+    def preprocess_request(
         self, request: ChatDetectionRequest
     ) -> Union[ChatDetectionRequest, ErrorResponse]:
         """Preprocess chat request"""
@@ -185,14 +191,14 @@ async def chat(
 
         # Apply task template if it exists
         if self.task_template:
-            request = self.apply_task_template_to_chat(request)
+            request = self.apply_task_template(request, fn_type=DetectorType.TEXT_CHAT)
             if isinstance(request, ErrorResponse):
                 # Propagate any request problems that will not allow
                 # task template to be applied
                 return request
 
         # Optionally make model-dependent adjustments for the request
-        request = self.preprocess_chat_request(request)
+        request = self.preprocess_request(request, fn_type=DetectorType.TEXT_CHAT)
 
         chat_completion_request = request.to_chat_completion_request(model_name)
         if isinstance(chat_completion_request, ErrorResponse):
34 changes: 23 additions & 11 deletions vllm_detector_adapter/generative_detectors/granite_guardian.py
@@ -8,13 +8,15 @@
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
 
 # Local
+from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
 from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
 from vllm_detector_adapter.logging import init_logger
 from vllm_detector_adapter.protocol import (
     ChatDetectionRequest,
     ContextAnalysisRequest,
     DetectionResponse,
 )
+from vllm_detector_adapter.utils import DetectorType
 
 logger = init_logger(__name__)
 
@@ -33,7 +35,9 @@ class GraniteGuardian(ChatCompletionDetectionBase):
     PROMPT_CONTEXT_ANALYSIS_RISKS = ["context_relevance"]
     RESPONSE_CONTEXT_ANALYSIS_RISKS = ["groundedness"]
 
-    def preprocess(
+    ##### Private / Internal functions ###################################################
+
+    def __preprocess(
         self, request: Union[ChatDetectionRequest, ContextAnalysisRequest]
     ) -> Union[ChatDetectionRequest, ContextAnalysisRequest, ErrorResponse]:
         """Granite guardian specific parameter updates for risk name and risk definition"""
@@ -59,13 +63,10 @@
 
         return request
 
-    def preprocess_chat_request(
-        self, request: ChatDetectionRequest
-    ) -> Union[ChatDetectionRequest, ErrorResponse]:
-        """Granite guardian chat request preprocess is just detector parameter updates"""
-        return self.preprocess(request)
-
-    def request_to_chat_completion_request(
+    # Decorated so that future iterations of this function can cleanly support
+    # other detector types
+    @detector_dispatcher(types=[DetectorType.TEXT_CONTEXT_DOC])
+    def _request_to_chat_completion_request(
         self, request: ContextAnalysisRequest, model_name: str
     ) -> Union[ChatCompletionRequest, ErrorResponse]:
         NO_RISK_NAME_MESSAGE = "No risk_name for context analysis"
@@ -141,6 +142,17 @@ def request_to_chat_completion_request(
                 code=HTTPStatus.BAD_REQUEST.value,
             )
 
+    ##### General request / response processing functions ##################
+
+    # The detector_dispatcher decorator allows the same function name to be used
+    # for different detector types with different request types
+    @detector_dispatcher(types=[DetectorType.TEXT_CHAT])
+    def preprocess_request(
+        self, request: ChatDetectionRequest
+    ) -> Union[ChatDetectionRequest, ErrorResponse]:
+        """Granite guardian chat request preprocess is just detector parameter updates"""
+        return self.__preprocess(request)
+
     async def context_analyze(
         self,
         request: ContextAnalysisRequest,
@@ -152,13 +164,13 @@ async def context_analyze(
 
         # Task template not applied for context analysis at this time
         # Make model-dependent adjustments for the request
-        request = self.preprocess(request)
+        request = self.__preprocess(request)
 
         # Since particular chat messages are dependent on Granite Guardian risk definitions,
         # the processing is done here rather than in a separate, general to_chat_completion_request
         # for all context analysis requests.
-        chat_completion_request = self.request_to_chat_completion_request(
-            request, model_name
+        chat_completion_request = self._request_to_chat_completion_request(
+            request, model_name, fn_type=DetectorType.TEXT_CONTEXT_DOC
         )
         if isinstance(chat_completion_request, ErrorResponse):
             # Propagate any request problems
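Two details worth noting in this file: the double underscore in __preprocess means Python name-mangles it to _GraniteGuardian__preprocess, keeping the helper private to this class; and because the dispatcher keys its registry on __qualname__, GraniteGuardian.preprocess_request and the base class's ChatCompletionDetectionBase.preprocess_request occupy separate registry entries, so both can register DetectorType.TEXT_CHAT without triggering the duplicate-registration error.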
11 changes: 11 additions & 0 deletions vllm_detector_adapter/utils.py
@@ -0,0 +1,11 @@
# Standard
from enum import Enum, auto


class DetectorType(Enum):
    """Enum to represent different types of detectors"""

    TEXT_CONTENT = auto()
    TEXT_GENERATION = auto()
    TEXT_CHAT = auto()
    TEXT_CONTEXT_DOC = auto()
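Since the members use auto(), their values are simply 1 through 4 in declaration order; the dispatcher only uses the members as dictionary keys, so the concrete values never matter. For example (a sketch):

from vllm_detector_adapter.utils import DetectorType

assert DetectorType.TEXT_CONTENT.value == 1
assert DetectorType.TEXT_CONTEXT_DOC.value == 4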
