Updated inline vllm inference provider #880

Open
wants to merge 9 commits into main
Conversation

@frreiss (Contributor) commented Jan 26, 2025

What does this PR do?

This PR updates the inline vLLM inference provider in several significant ways:

  • Models are now attached at run time to instances of the provider via the .../models API instead of hard-coding the model's full name into the provider's YAML configuration (a registration sketch follows this list).
  • The provider supports models that are not Meta Llama models. Any model that vLLM supports can be loaded by passing Hugging Face coordinates in the provider_model_id field. Custom fine-tuned versions of Meta Llama models can be loaded by specifying a path on local disk in provider_model_id.
  • To implement full chat completions support, including tool calling and constrained decoding, the provider now routes the chat_completions API to a captive (i.e., called directly in-process, not via HTTPS) instance of vLLM's OpenAI-compatible server.
  • The logprobs parameter and the completions API also work.
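
For illustration, here is a minimal sketch of attaching a model at run time through the models API with the Python client. The server URL/port, the provider_id value "vllm", and the exact keyword names are assumptions for the sake of the example, not values this PR prescribes:

from llama_stack_client import LlamaStackClient

# Point the client at a running Llama Stack server (URL and port are illustrative).
client = LlamaStackClient(base_url="http://localhost:8321")

# Attach a model to the inline vLLM provider at run time. provider_model_id can be
# Hugging Face coordinates or a local path to a fine-tuned checkpoint.
client.models.register(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    provider_model_id="meta-llama/Llama-3.2-3B-Instruct",
    provider_id="vllm",
)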

Test Plan

Existing tests in llama_stack/providers/tests/inference/test_text_inference.py have good coverage of the new functionality. These tests can be invoked as follows:

cd llama-stack && pytest \
    -vvv \
    llama_stack/providers/tests/inference/test_text_inference.py \
    --providers inference=vllm \
    --inference-model meta-llama/Llama-3.2-3B-Instruct
====================================== test session starts ======================================
platform linux -- Python 3.12.8, pytest-8.3.4, pluggy-1.5.0 -- /mnt/datadisk1/freiss/llama/env/bin/python3.12
cachedir: .pytest_cache
metadata: {'Python': '3.12.8', 'Platform': 'Linux-6.8.0-1016-ibm-x86_64-with-glibc2.39', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'anyio': '4.8.0', 'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.2'}, 'JAVA_HOME': '/usr/lib/jvm/java-8-openjdk-amd64'}
rootdir: /mnt/datadisk1/freiss/llama/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0, html-4.1.1, metadata-3.1.1, asyncio-0.25.2
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None
collected 9 items                                                                               

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-vllm] PASSED [ 11%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-vllm] PASSED [ 22%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_logprobs[-vllm] PASSED [ 33%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-vllm] PASSED [ 44%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-vllm] PASSED [ 55%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-vllm] PASSED [ 66%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-vllm] PASSED [ 77%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-vllm] PASSED [ 88%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-vllm] PASSED [100%]

=========================== 9 passed, 13 warnings in 97.18s (0:01:37) ===========================

Sources

Before submitting

  • Ran pre-commit to handle lint / formatting issues.
  • Read the contributor guideline,
    Pull Request section?
  • Updated relevant documentation.
  • Wrote necessary unit or integration tests.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 26, 2025

@ashwinb (Contributor) left a comment:

Wonderful PR, thank you!

I have a few comments inline.

"messages": converted_messages,
"tools": converted_tools,
"tool_choice": converted_tool_choice,
"stream": request.stream,

Contributor:

nit: a bit more idiomatic python to write

request_options = {
  "model": ...,
  **sampling_options,
  **guided_decoding_options,
  **logprob_options
}

# OpenAI's APIs don't know about.
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
# For now, translate repetition penalties into a format that vLLM's broken
# API will handle correctly. Two wrongs make a right...

Contributor:

:)

):
converted_tool_choice = "auto"

# TODO: Figure out what to do with the tool_prompt_format argument.

Contributor:

so this is rather important actually when the underlying model is a llama model. See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/utils/inference/prompt_adapter.py#L286-L297 for how we try to adapt the tool formatting to the underlying llama model. each llama model is a special snowflake :/

my recommendation therefore is to treat llama models specially when routing to vLLM. when you detect a model is a llama model (we use metadata.llama_model from the model registration info elsewhere for this purpose), you should route it to the "raw" completions API and keep control of prompt formatting within the Stack. otherwise, you can invoke the path you have implemented.

let me know your thoughts.
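
For concreteness, a minimal sketch of the detection step suggested above, assuming the registration metadata carries an optional llama_model entry as described; the function name and signature are illustrative:

from typing import Any, Mapping

def is_llama_model(metadata: Mapping[str, Any]) -> bool:
    # Registration metadata is assumed to carry "llama_model" when the underlying
    # model is a Meta Llama model. A provider can branch on this to keep prompt
    # formatting inside the Stack (raw completions path) rather than delegating it
    # to vLLM's OpenAI-compatible chat endpoint.
    return metadata.get("llama_model") is not None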

Contributor Author:

I can certainly take that route for now.

It would be good in the longer term to have the different Llama model tool formats fully integrated into the vLLM engine. That way, systems that use a vLLM-only inference stack will see results consistent with Llama Stack.

import vllm.sampling_params

############################################################################
# llama_models imports go here

Contributor:

nit: I don't think these comments are very useful

Contributor Author:

Will remove these comments.

############################################################################
# vLLM imports go here
#
# We deep-import the names that don't conflict with Llama Stack names

Contributor:

👍

return None

# Llama Stack currently implements fewer types of constrained
# decoding than vLLM does. Translate the types that exist and

Contributor:

🙏

vllm_top_p = 1.0
vllm_temperature = 0.0

# vLLM allows top-p and top-k at the same time.

Contributor:

haha what is the implementation in that case? I wasn't aware this combination could make sense!
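
For what it's worth, samplers typically compose the two filters rather than choosing one: top-k caps the candidate set, and top-p then keeps the smallest prefix of those survivors that reaches the cumulative-probability threshold. A hedged sketch using vLLM's SamplingParams (the specific values are illustrative):

from vllm import SamplingParams

# Both filters active at once: keep at most 40 candidates, then apply nucleus
# sampling (p = 0.9) over whatever survives the top-k cut.
params = SamplingParams(temperature=0.7, top_p=0.9, top_k=40)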

guided_decoding_backend="lm-format-enforcer",
)
###########################################################################
# METHODS INHERITED FROM UNDOCUMENTED IMPLICIT MYSTERY BASE CLASS

Contributor:

hehehe, legit feedback and taken! I will fix this by making a ProviderBase and documenting the lifecycle properly
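
Purely as an illustration of what such a documented base class might look like (the name ProviderBase comes from the comment above; the hooks are a guess, not a committed design):

from abc import ABC, abstractmethod

class ProviderBase(ABC):  # hypothetical base class for inline providers
    @abstractmethod
    async def initialize(self) -> None:
        """Called once after construction, before any requests are routed."""

    @abstractmethod
    async def shutdown(self) -> None:
        """Called when the Stack tears the provider down."""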

Contributor Author:

Updating this comment accordingly.

if self.resolved_model_id is not None:
    if resolved_model_id != self.resolved_model_id:
        raise ValueError(
            f"Attempted to serve two LLMs (ids "

Contributor:

I believe your line width is rather small and isn't really consistent with how the rest of the Stack code looks. Could I ask you to make it a bit wider?

Contributor Author:

Sure, can do. From the llama_reference provider code, it looks like the standard line width is 100. Is that correct?

Callback that is called when the server removes an inference endpoint
from an inference provider.

The semantics of this callback are not clear. How should model_id

Contributor:

yeah that's fair. we will add appropriate documentation.

model_id is the same ID you got in register_model().

the behavioral semantics are left up to the provider -- it basically means the Stack will no longer recognize this model and if the provider wants to do any resource deallocation (e.g., maybe they could send an API call to unwind a deployment, etc.) this callback is the place to initiate it.

Contributor Author:

Thanks for that clarification. I'll update the pydoc comment and add some code here to remove the selected model ID and shut down the connection to vLLM if no more IDs are being served.
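
As a rough illustration of that plan (the class name, the attributes served_model_ids and engine, and the shutdown call are assumptions, not the provider's actual implementation):

from typing import Any

class InlineVLLMProvider:  # illustrative stand-in for the provider class
    served_model_ids: set[str]
    engine: Any  # captive vLLM engine instance, or None once shut down

    async def unregister_model(self, model_id: str) -> None:
        # Stop tracking the model the Stack is removing.
        self.served_model_ids.discard(model_id)
        # If no models remain, release the captive vLLM engine.
        if not self.served_model_ids and self.engine is not None:
            self.engine.shutdown_background_loop()  # assumed engine teardown hook
            self.engine = None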

# first one.
if len(vllm_result.choices) == 0:
    raise ValueError(
        "Don't know how to convert response object without any " "responses"

Collaborator:

Double quotes need to be escaped here

Comment on lines +111 to +119
def _log(msg: str, level: str):
    if _BYPASS_LOGGING:
        time_str = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"{time_str}: {msg}")
    match level:
        case "info":
            logger.info(msg)
        case "debug":
            logger.debug(msg)

Collaborator:

This seems a bit hacky. We should probably fix and improve the logging separately
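
One lightweight alternative, sketched here under the assumption that plain stdlib logging is acceptable (the logger name, format, and level are illustrative choices, not the project's actual logging plan):

import logging

logger = logging.getLogger("llama_stack.providers.inline.vllm")
if not logger.handlers:
    # Attach a console handler once (for example at provider initialization) so
    # that debug/info messages reach the terminal without a custom bypass flag.
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s", datefmt="%H:%M:%S"))
    logger.addHandler(handler)
logger.setLevel(logging.DEBUG)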

None if sampling_params.max_tokens == 0 else sampling_params.max_tokens
),
# Assume that vLLM's default stop token will work
# stop_token_ids=[tokenizer.eos_token_id],

Collaborator:

Is this still needed?
