Merge branch 'upstream_main' into roberta
flaviabeo committed Nov 13, 2024
2 parents 366a992 + 0b8bb86 commit aed1216
Showing 104 changed files with 3,453 additions and 783 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.neuron
@@ -31,7 +31,7 @@ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi

RUN python3 -m pip install -U \
-'cmake>=3.26,<=3.30' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
+'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements-neuron.txt

ENV VLLM_TARGET_DEVICE neuron
2 changes: 1 addition & 1 deletion Dockerfile.ppc64le
@@ -21,7 +21,7 @@ RUN --mount=type=bind,source=.git,target=.git \
# These packages will be in rocketce eventually
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
-'cmake>=3.26,<=3.30' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
+'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
torch==2.3.1 \
-r requirements-cpu.txt \
xformers uvloop==0.20.0
66 changes: 38 additions & 28 deletions csrc/prepare_inputs/advance_step.cu
@@ -88,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor const& t,
}
}

/// each thread processes a block per query
__global__ void advance_step_flashinfer_kernel(
int num_threads, int num_seqs, int num_queries, int block_size,
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
@@ -134,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel(
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;

// Update paged_kv_indptr
if (idx == 0) {
paged_kv_indptr_ptr[idx] = 0;
}
if (idx < num_queries) {
int sum = 0;
for (int i = 0; i <= idx; ++i) {
@@ -146,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel(
}

__global__ void advance_step_flashinfer_indices_kernel(
-int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
-int64_t const block_tables_stride, int* paged_kv_indices_ptr,
+int num_seqs, int num_queries, int const* block_tables_ptr,
+int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
-int idx = blockIdx.x * num_threads + threadIdx.x;
-int row = idx / block_tables_stride;
-int col = idx % block_tables_stride;
-
-if (row < num_queries && col < block_table_bound_ptr[row]) {
-paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
-block_tables_ptr[row * block_tables_stride + col];
// note: max_num_blocks_per_seq = block_tables.stride(0)
int tid = blockIdx.x * blockDim.x + threadIdx.x;

// when cuda graphs are enabled, paged_kv_indptr tensor
// has to be updated for the padded queries
// tid represents a query# for paged_kv_indptr tensor
if (num_queries < tid && tid <= num_seqs) {
paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
}
-// if cudagraph, fill padded seqs with the last valid seq's indptr
-if (num_queries < row && row <= num_seqs) {
-paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];

// each thread processes a block_ptr in block_tables
// block_tables shape: [num_queries, max_num_blocks_per_seq]
// paged_kv_indices is flattened block_tables.
for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
idx += (gridDim.x * blockDim.x)) {
// block_tables-row = paged_kv_indptr[queryNum]
int queryNum = idx / max_num_blocks_per_seq;
int col = idx % max_num_blocks_per_seq;
if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
paged_kv_indices_ptr[indices_arr_idx] =
block_tables_ptr[block_tables_idx];
}
}
}

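For reference, a minimal Python sketch of the flattened indexing the rewritten kernel performs (illustrative only; the function name is hypothetical, and block_tables is assumed to be a list of per-sequence block-id lists):

# Each CUDA thread handles one (query, column) slot of block_tables; the Python
# loop below plays the role of the grid-stride loop over thread ids.
def fill_paged_kv_indices(block_tables, block_table_bound, paged_kv_indptr,
                          num_queries, max_num_blocks_per_seq):
    paged_kv_indices = {}
    num_seqs = len(block_tables)
    for idx in range(num_seqs * max_num_blocks_per_seq):
        query_num = idx // max_num_blocks_per_seq  # row in block_tables
        col = idx % max_num_blocks_per_seq         # column within that row
        if query_num < num_queries and col < block_table_bound[query_num]:
            # write into the flattened array at the offset given by indptr
            paged_kv_indices[paged_kv_indptr[query_num] + col] = \
                block_tables[query_num][col]
    return paged_kv_indices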
@@ -247,22 +263,16 @@ void advance_step_flashinfer(
int threads;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
-if (logging) {
-printf("launching kernel with %d blocks\n", blocks);
-}
-
-// TODO(will): support arbitrary block_tables stride
-if ((blocks * threads) / block_tables.stride(0) < num_queries) {
-TORCH_CHECK(false,
-"multi-step: not enough threads to map block_table to"
-"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
-"of seqs,",
-" increasing the block size or take smaller steps.",
-" num_queries = ", num_queries,
-" block_tables.stride(0) = ", block_tables.stride(0),
-" blocks = ", blocks, " max_threads = ", threads);
int block_tables_stride = block_tables.stride(0);
TORCH_CHECK((blocks * threads > num_queries),
"multi-step: not enough threads to map to num_queries = ",
num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
" blocks = ", blocks, " max_threads = ", threads);
if (logging) {
printf("launching kernels with %d blocks and %d threads\n", blocks,
threads);
}

advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
@@ -281,7 +291,7 @@ void advance_step_flashinfer(
reinterpret_cast<int*>(block_table_bound.data_ptr()));

advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
-threads, num_seqs, num_queries,
+num_seqs, num_queries,
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
4 changes: 3 additions & 1 deletion docs/source/_static/custom.js
@@ -8,7 +8,9 @@ document.addEventListener("DOMContentLoaded", function () {
script.setAttribute("version", "stable");
script.setAttribute("runllm-keyboard-shortcut", "Mod+j"); // cmd-j or ctrl-j to open the widget.
script.setAttribute("runllm-name", "vLLM");
script.setAttribute("runllm-position", "TOP_RIGHT");
script.setAttribute("runllm-position", "BOTTOM_RIGHT");
script.setAttribute("runllm-position-y", "20%");
script.setAttribute("runllm-position-x", "3%");
script.setAttribute("runllm-assistant-id", "207");

script.async = true;
2 changes: 1 addition & 1 deletion docs/source/getting_started/cpu-installation.rst
@@ -62,7 +62,7 @@ Build from source
.. code-block:: console
$ pip install --upgrade pip
-$ pip install cmake>=3.26,<=3.30 wheel packaging ninja "setuptools-scm>=8" numpy
+$ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy
$ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
- Finally, build and install vLLM CPU backend:
4 changes: 4 additions & 0 deletions docs/source/getting_started/debugging.rst
@@ -20,6 +20,10 @@ Hangs loading a model from disk
If the model is large, it can take a long time to load it from disk. Pay attention to where you store the model. Some clusters have shared filesystems across nodes, e.g. a distributed filesystem or a network filesystem, which can be slow.
It is better to store the model on a local disk. Additionally, keep an eye on CPU memory usage: when the model is too large, it can consume a lot of CPU memory, which slows down the operating system because it has to swap frequently between disk and memory.

.. note::

To isolate the model downloading and loading issue, you can use the ``--load-format dummy`` argument to skip loading the model weights. This way, you can check if the model downloading and loading is the bottleneck.
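For instance, a minimal offline sketch of the same check (assuming the Python ``LLM`` entrypoint accepts ``load_format`` as a keyword, mirroring the CLI flag above; the model name is only a placeholder):

.. code-block:: python

from vllm import LLM

# Dummy-format loading allocates random weights instead of reading the
# checkpoint from disk, so any slowness that remains is not caused by
# downloading or loading the weights.
llm = LLM(model="facebook/opt-125m", load_format="dummy")
print(llm.generate("Hello, my name is"))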

Model is too large
----------------------------------------
If the model is too large to fit in a single GPU, you might want to `consider tensor parallelism <https://docs.vllm.ai/en/latest/serving/distributed_serving.html#distributed-inference-and-serving>`_ to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using `this example <https://docs.vllm.ai/en/latest/getting_started/examples/save_sharded_state.html>`_ . The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
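As a rough sketch, tensor parallelism can also be requested from the Python entrypoint; ``tensor_parallel_size`` mirrors the ``--tensor-parallel-size`` CLI flag, and the model name below is only a placeholder:

.. code-block:: python

from vllm import LLM

# With tensor_parallel_size=4, each of the 4 workers reads the full checkpoint
# and keeps only its own shard, which is why a pre-sharded checkpoint loads faster.
llm = LLM(model="meta-llama/Meta-Llama-3.1-70B-Instruct", tensor_parallel_size=4)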
2 changes: 1 addition & 1 deletion docs/source/models/enabling_multimodal_inputs.rst
@@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i
3. Register maximum number of multi-modal tokens
------------------------------------------------

-For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance
+For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.

.. code-block:: diff
10 changes: 8 additions & 2 deletions docs/source/models/supported_models.rst
@@ -450,7 +450,7 @@ Text Generation
- Idefics3
- T + I
- :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc.
-
- ✅︎
-
* - :code:`InternVLChatModel`
- InternVL2
@@ -538,7 +538,7 @@ Text Generation
- ✅︎
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL
-- T + I\ :sup:`E+` + V\ :sup:`+`
+- T + I\ :sup:`E+` + V\ :sup:`E+`
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
- ✅︎
- ✅︎
@@ -584,6 +584,12 @@ Multimodal Embedding
- :code:`TIGER-Lab/VLM2Vec-Full`
- 🚧
- ✅︎
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL-based
- T + I
- :code:`MrLight/dse-qwen2-2b-mrl-v1`
-
- ✅︎

.. important::
Some model architectures support both generation and embedding tasks.
17 changes: 17 additions & 0 deletions docs/source/models/vlm.rst
@@ -310,4 +310,21 @@ Since the request schema is not defined by OpenAI client, we post a request to t
response_json = response.json()
print("Embedding output:", response_json["data"][0]["embedding"])
Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model.

.. code-block:: bash
vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embedding \
--trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja
.. important::

Like with VLM2Vec, we have to explicitly pass ``--task embedding``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings,
which is handled by the jinja template.

.. important::

``MrLight/dse-qwen2-2b-mrl-v1`` also requires a placeholder image of the minimum image size for text query embeddings. See the full code
example below for details.

A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_embedding_client_for_multimodal.py>`_.
123 changes: 105 additions & 18 deletions examples/openai_chat_embedding_client_for_multimodal.py
@@ -1,33 +1,120 @@
import argparse
import base64
import io

import requests
from PIL import Image

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

-response = requests.post(
-"http://localhost:8000/v1/embeddings",
-json={
-"model":
-"TIGER-Lab/VLM2Vec-Full",
-"messages": [{

def vlm2vec():
response = requests.post(
"http://localhost:8000/v1/embeddings",
json={
"model":
"TIGER-Lab/VLM2Vec-Full",
"messages": [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "Represent the given image."
},
],
}],
"encoding_format":
"float",
},
)
response.raise_for_status()
response_json = response.json()

print("Embedding output:", response_json["data"][0]["embedding"])


def dse_qwen2_vl(inp: dict):
# Embedding an Image
if inp["dtype"] == "image":
messages = [{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": inp["image_url"],
}
}, {
"type": "text",
"text": "What is shown in this image?"
}]
}]
# Embedding a Text Query
else:
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
# of the minimum input size
buffer = io.BytesIO()
image_placeholder = Image.new("RGB", (56, 56))
image_placeholder.save(buffer, "png")
buffer.seek(0)
image_placeholder = base64.b64encode(buffer.read()).decode('utf-8')
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
"url": f"data:image/jpeg;base64,{image_placeholder}",
}
},
{
"type": "text",
"text": "Represent the given image."
"text": f"Query: {inp['content']}"
},
-],
-}],
-"encoding_format":
-"float",
-},
-)
-response.raise_for_status()
-response_json = response.json()
-
-print("Embedding output:", response_json["data"][0]["embedding"])
]
}]

response = requests.post(
"http://localhost:8000/v1/embeddings",
json={
"model": "MrLight/dse-qwen2-2b-mrl-v1",
"messages": messages,
"encoding_format": "float",
},
)
response.raise_for_status()
response_json = response.json()

print("Embedding output:", response_json["data"][0]["embedding"])


if __name__ == '__main__':
parser = argparse.ArgumentParser(
"Script to call a specified VLM through the API. Make sure to serve "
"the model with --task embedding before running this.")
parser.add_argument("model",
type=str,
choices=["vlm2vec", "dse_qwen2_vl"],
help="Which model to call.")
args = parser.parse_args()

if args.model == "vlm2vec":
vlm2vec()
elif args.model == "dse_qwen2_vl":
dse_qwen2_vl({
"dtye": "image",
"image_url": image_url,
})
dse_qwen2_vl({
"dtype": "text",
"content": "What is the weather like today?",
})
7 changes: 7 additions & 0 deletions examples/template_dse_qwen2_vl.jinja
@@ -0,0 +1,7 @@
{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{% raw %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endraw %}{% endif %}<|im_start|>{{ message['role'] }}{% raw %}
{% endraw %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>{% raw %}
{% endraw %}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>{% raw %}
{% endraw %}{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant{% raw %}
{% endraw %}{% endif %}<|endoftext|>
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[build-system]
# Should be mirrored in requirements-build.txt
requires = [
"cmake>=3.26,<=3.30",
"cmake>=3.26",
"ninja",
"packaging",
"setuptools>=61",
2 changes: 1 addition & 1 deletion requirements-build.txt
@@ -1,5 +1,5 @@
# Should be mirrored in pyproject.toml
-cmake>=3.26,<=3.30
+cmake>=3.26
ninja
packaging
setuptools>=61
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -31,4 +31,4 @@ pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL.
-compressed-tensors == 0.7.1 # required for compressed-tensors
+compressed-tensors == 0.8.0 # required for compressed-tensors
2 changes: 1 addition & 1 deletion requirements-tpu.txt
@@ -2,7 +2,7 @@
-r requirements-common.txt

# Dependencies for TPU
-cmake>=3.26,<=3.30
+cmake>=3.26
ninja
packaging
setuptools-scm>=8