[Core] Adopting vllm to external KV store Vineyard #1

Merged: 2 commits merged on Oct 1, 2024

happyandslow (Collaborator):

Adopting vllm to use an external KV service (vineyard, v6d) for the KV cache.
This PR adopts the vllm integration provided by the vineyard team: sighingnow@d347dab#diff-c89ac25bd066e936e80260d21be63c7d2379cfedc371a9ff288fb5ba02ae1350

happyandslow (Collaborator, Author) commented Jul 19, 2024

Steps to set up vineyard and vllm for testing (partial credit to @haiyang from his internal doc):

Prepare v6d

If you plan to use the shared memory cache, a specific line in the v6d source must be commented out before building (a hack; see the shared memory cache notes below).

sudo apt-get update
sudo apt-get install -y ca-certificates \
                   cmake \
                   doxygen \
                   libboost-all-dev \
                   libcurl4-openssl-dev \
                   libgflags-dev \
                   libgoogle-glog-dev \
                   libgrpc-dev \
                   libgrpc++-dev \
                   libmpich-dev \
                   libprotobuf-dev \
                   libssl-dev \
                   libunwind-dev \
                   libz-dev \
                   protobuf-compiler-grpc \
                   python3-pip \
                   wget \
                   ninja-build

sudo rm -rf /lib/x86_64-linux-gnu/libprotoc.a \
            /lib/x86_64-linux-gnu/libprotobuf.a \
            /lib/x86_64-linux-gnu/libprotobuf-lite.a \
            /lib/x86_64-linux-gnu/libprotobuf.so.23 \
            /lib/x86_64-linux-gnu/libprotobuf.so.23.0.4

sudo ldconfig

# we have to build arrow from source to use the system-wide protobuf
cd ~
git clone https://github.com/apache/arrow.git
cd arrow/cpp && git checkout apache-arrow-16.1.0 && mkdir build-release && cd build-release
cmake --preset ninja-release-python -DCMAKE_INSTALL_PREFIX=/usr/ -DProtobuf_PROTOC_LIBRARY=/lib/x86_64-linux-gnu/libprotoc.so.32 ..
cmake --build .
sudo ninja install

sudo pip3 install cython
cd ~/arrow/python
sudo python3 setup.py install

cd ~
git clone https://github.com/v6d-io/v6d
cd v6d  && git submodule update --init --recursive
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release \
         -DBUILD_SHARED_LIBS=ON \
         -DUSE_STATIC_BOOST_LIBS=OFF \
         -DBUILD_VINEYARD_SERVER=ON \
         -DBUILD_VINEYARD_CLIENT=OFF \
         -DBUILD_VINEYARD_PYTHON_BINDINGS=ON \
         -DBUILD_VINEYARD_PYPI_PACKAGES=OFF \
         -DBUILD_VINEYARD_LLM_CACHE=ON \
         -DBUILD_VINEYARD_BASIC=OFF \
         -DBUILD_VINEYARD_GRAPH=OFF \
         -DBUILD_VINEYARD_IO=OFF \
         -DBUILD_VINEYARD_HOSSEINMOEIN_DATAFRAME=OFF \
         -DBUILD_VINEYARD_TESTS=ON \
         -DBUILD_VINEYARD_TESTS_ALL=OFF \
         -DBUILD_VINEYARD_PROFILING=OFF
make -j
make vineyard_llm_python -j
sudo make install

sudo pip3 install cython
cd ~/v6d
sudo python3 setup.py install
sudo python3 setup_llm.py install

pip3 install google-benchmark
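After the install steps, a quick sanity check is to confirm that the Python packages can be located. The helper below is a hypothetical sketch, not part of v6d or vllm; the module names `vineyard`, `vineyard.llm`, and `pyarrow` are assumptions based on the install commands above and the traceback later in this thread.

```python
import importlib.util


def check_modules(names):
    """Return {module_name: bool} telling whether each module can be found.

    importlib.util.find_spec locates a module without executing it; for a
    dotted name the parent package is imported first, so a missing parent
    raises ModuleNotFoundError, which is reported as "not found" here.
    """
    results = {}
    for name in names:
        try:
            results[name] = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            results[name] = False
    return results


if __name__ == "__main__":
    # Module names are assumptions taken from the setup steps in this thread.
    for module, ok in check_modules(["vineyard", "vineyard.llm", "pyarrow"]).items():
        print(f"{module}: {'ok' if ok else 'MISSING'}")
```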

To start etcd

ETCD_VER=v3.4.33

# choose either URL
GOOGLE_URL=https://storage.googleapis.com/etcd
GITHUB_URL=https://github.com/etcd-io/etcd/releases/download
DOWNLOAD_URL=${GOOGLE_URL}

rm -f /tmp/etcd-${ETCD_VER}-linux-amd64.tar.gz
rm -rf /tmp/etcd-download-test && mkdir -p /tmp/etcd-download-test

curl -L ${DOWNLOAD_URL}/${ETCD_VER}/etcd-${ETCD_VER}-linux-amd64.tar.gz -o /tmp/etcd-${ETCD_VER}-linux-amd64.tar.gz
tar xzvf /tmp/etcd-${ETCD_VER}-linux-amd64.tar.gz -C /tmp/etcd-download-test --strip-components=1
rm -f /tmp/etcd-${ETCD_VER}-linux-amd64.tar.gz

/tmp/etcd-download-test/etcd --version
/tmp/etcd-download-test/etcdctl version

# start a local etcd server
/tmp/etcd-download-test/etcd 1>/dev/null 2>&1 &

# clear
# ETCDCTL_API=3 /tmp/etcd-download-test/etcdctl del "" --from-key=true
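The download step only composes a release URL from ETCD_VER and one of the two mirrors. As a sketch, here is the same composition in Python, plus a crude check that something is accepting connections on etcd's default client port 2379 (assumed here, since the command above does not override it). Both helper names are hypothetical.

```python
import socket

GOOGLE_URL = "https://storage.googleapis.com/etcd"
GITHUB_URL = "https://github.com/etcd-io/etcd/releases/download"


def etcd_tarball_url(version, base_url=GOOGLE_URL):
    # Same layout as the shell step:
    # ${DOWNLOAD_URL}/${ETCD_VER}/etcd-${ETCD_VER}-linux-amd64.tar.gz
    return f"{base_url}/{version}/etcd-{version}-linux-amd64.tar.gz"


def etcd_is_listening(host="127.0.0.1", port=2379, timeout=1.0):
    # Only checks that the port accepts TCP connections; it does not
    # verify that the process behind it is a healthy etcd server.
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
```

For example, `etcd_is_listening()` should return `True` once the background `etcd` process from the step above is up.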

Build and run vllm

# start a vineyard server
~/v6d/build/bin/vineyardd --socket /tmp/vineyard_test.sock 1>out.log 2>&1 &

# build and install vllm
cd ~ && git clone https://github.com/aibrix/vllm.git
cd vllm && git checkout lexu/vineyard-adptation 
pip3 install -e .

# upgrade pyarrow
python -m pip install pyarrow --upgrade
# rebuild and reinstall vineyard and its LLM cache bindings
cd ~/v6d && sudo python3 setup.py install && sudo python3 setup_llm.py install

export VLLM_USE_VINEYARD_CACHE=1
export VLLM_USE_FLASH_ATTN_DECODING=1

There are two ways to use the cache:

If using the file system cache

export VINEYARD_LLM_CACHE_FILESYSTEM=1
export VINEYARD_LLM_CACHE_FILESYSTEM_CHUNK_SIZE=16
export VINEYARD_LLM_CACHE_FILESYSTEM_HASH_CHUNK_SIZE=2
export VINEYARD_LLM_CACHE_FILESYSTEM_ROOT="/tmp/vineyard/llm_cache"
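These settings are read from the environment by the cache integration. A hypothetical sketch of that parsing is below: the `FileCacheSettings` dataclass and `read_file_cache_settings` function are illustrative names only, and the defaults simply mirror the values exported above rather than any documented vineyard defaults.

```python
import os
from dataclasses import dataclass


@dataclass
class FileCacheSettings:
    # Hypothetical container; fields follow the environment variables above.
    enabled: bool
    chunk_size: int
    hash_chunk_size: int
    root: str


def read_file_cache_settings(env=None):
    # Accepts any mapping so it can be tested without touching os.environ.
    env = os.environ if env is None else env
    return FileCacheSettings(
        enabled=env.get("VINEYARD_LLM_CACHE_FILESYSTEM", "0") == "1",
        chunk_size=int(env.get("VINEYARD_LLM_CACHE_FILESYSTEM_CHUNK_SIZE", "16")),
        hash_chunk_size=int(env.get("VINEYARD_LLM_CACHE_FILESYSTEM_HASH_CHUNK_SIZE", "2")),
        root=env.get("VINEYARD_LLM_CACHE_FILESYSTEM_ROOT", "/tmp/vineyard/llm_cache"),
    )
```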

If using the shared memory cache (hack, WIP)

Make sure the v6d line mentioned under "Prepare v6d" is commented out before you build v6d.

Then:

export VINEYARD_LLM_CACHE_SHARED_MEMORY=1
export VINEYARD_LLM_CACHE_SHARED_MEMORY_SOCKET="/tmp/vineyard_test.sock"
export VINEYARD_LLM_CACHE_SHARED_BLOCK_SIZE=5
export VINEYARD_LLM_CACHE_SHARED_MEMORY_SYNC_INTERVAL=3
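The socket exported here has to match the one passed to vineyardd (`--socket /tmp/vineyard_test.sock` in the step that starts the vineyard server). A hypothetical pre-flight check is sketched below; it assumes the shared memory cache counts as enabled only when VINEYARD_LLM_CACHE_SHARED_MEMORY is exactly "1", and the function name is made up for illustration.

```python
import os
import stat


def check_vineyard_socket(env=None):
    """Return (ok, message) for the shared memory cache socket setting."""
    env = os.environ if env is None else env
    if env.get("VINEYARD_LLM_CACHE_SHARED_MEMORY") != "1":
        return True, "shared memory cache disabled; nothing to check"
    path = env.get("VINEYARD_LLM_CACHE_SHARED_MEMORY_SOCKET")
    if not path:
        return False, "VINEYARD_LLM_CACHE_SHARED_MEMORY_SOCKET is not set"
    try:
        mode = os.stat(path).st_mode
    except FileNotFoundError:
        return False, f"{path} does not exist; is vineyardd running with --socket {path}?"
    if not stat.S_ISSOCK(mode):
        return False, f"{path} exists but is not a UNIX domain socket"
    return True, f"found vineyardd socket at {path}"
```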

Run vllm

# start vllm
cd ~/vllm
python3 -m vllm.entrypoints.openai.api_server --model=facebook/opt-125m --enable-chunked-prefill

# for debug logs, export VLLM_LOGGING_LEVEL=DEBUG before starting the server

To test locally, run:

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
    "model": "facebook/opt-125m",
    "prompt": "San Francisco is a",
    "max_tokens": 7,
    "temperature": 0
}'
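The same request can be sent from Python with only the standard library. This is a sketch: `build_completion_request` and `complete` are hypothetical helper names, and the response parsing assumes the OpenAI-style completions format (`choices[0]["text"]`) served by vllm's OpenAI-compatible server.

```python
import json
import urllib.request


def build_completion_request(prompt, model="facebook/opt-125m", max_tokens=7,
                             temperature=0.0, base_url="http://localhost:8000"):
    # Same JSON payload as the curl example above.
    body = json.dumps({
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/v1/completions",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def complete(prompt, **kwargs):
    # Requires the api_server started above to be running.
    req = build_completion_request(prompt, **kwargs)
    with urllib.request.urlopen(req, timeout=60) as resp:
        payload = json.load(resp)
    return payload["choices"][0]["text"]
```

With the server running, `print(complete("San Francisco is a"))` prints the generated continuation.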

happyandslow (Collaborator, Author):

Problems encountered while cherry-picking commit sighingnow@d347dab#diff-c89ac25bd066e936e80260d21be63c7d2379cfedc371a9ff288fb5ba02ae1350 onto the latest branch:

  1. In the model runner, execute_model used to take seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] as an input argument, which the latest version no longer does. I have moved the metadata list into the output of prepare_input_tensors.

  2. Vineyard LLM Cache:

  • All torch.distributed functions seem to produce an error (screenshot not preserved). I have commented out all torch.distributed calls for single-node testing.
  • The CacheConfig argument is missing from the original implementation. I added FileConfig, which works for now, but adding VineyardCacheConfig produces an error (screenshot not preserved).

@@ -90,6 +90,12 @@ class ModelInputForGPU(ModelRunnerInputBase):
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0

    slot_mapping: Optional[torch.Tensor] = None

I cannot find a reference to this here. It seems not directly related to the vineyard changes?

happyandslow (Collaborator, Author):

seq_group_metadata_list should be the only field added. The remaining fields were ported from an older version of vllm, and I'll remove them.
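The change discussed here (the metadata list moving from an execute_model argument onto the output of prepare_input_tensors, as the only added field) can be illustrated with simplified stand-ins. None of these are vllm's real types: `SeqGroupMetadata` and `ModelInput` are hypothetical reductions of SequenceGroupMetadata and ModelInputForGPU, and the function bodies are placeholders.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SeqGroupMetadata:
    # Hypothetical stand-in for vllm's SequenceGroupMetadata.
    request_id: str
    token_ids: List[int]


@dataclass
class ModelInput:
    # Hypothetical stand-in for ModelInputForGPU; only the new field matters.
    input_tokens: List[int] = field(default_factory=list)
    seq_group_metadata_list: Optional[List[SeqGroupMetadata]] = None


def prepare_input_tensors(seq_group_metadata_list):
    tokens = [t for sg in seq_group_metadata_list for t in sg.token_ids]
    # Carry the metadata on the model input so execute_model no longer
    # needs it as a separate argument.
    return ModelInput(input_tokens=tokens,
                      seq_group_metadata_list=seq_group_metadata_list)


def execute_model(model_input):
    # A KV cache hook (e.g. the vineyard cache update) can read the metadata
    # straight from the model input.
    request_ids = [sg.request_id for sg in model_input.seq_group_metadata_list or []]
    return len(model_input.input_tokens), request_ids
```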

happyandslow (Collaborator, Author) commented Jul 23, 2024

Commenting out the null checks in v6d (https://github.com/v6d-io/v6d/blob/ebe8f077e3d3780a27d49238c501854b6b8e29df/modules/llm-cache/ds/kv_cache_block.cc#L163) allows the code to run with VineyardCacheConfig.

However, running a particular request sequence (screenshot not preserved) leads to a segmentation fault in the vllm client. A screenshot of the fault was also attached (not preserved).

happyandslow (Collaborator, Author):

Updated the instructions above: the shared memory cache can be enabled with the following settings:

export VINEYARD_LLM_CACHE_SHARED_MEMORY=1
export VINEYARD_LLM_CACHE_SHARED_MEMORY_SOCKET="/tmp/vineyard_test.sock"
export VINEYARD_LLM_CACHE_SHARED_BLOCK_SIZE=5
export VINEYARD_LLM_CACHE_SHARED_MEMORY_SYNC_INTERVAL=3

Jeffwan commented Aug 7, 2024

Let's create a separate branch and pin the commit ID. The remaining PRs should be submitted against that branch; otherwise, this PR will grow endlessly and be hard to collaborate on. @happyandslow @DwyaneShi

Jeffwan commented Aug 7, 2024

The feature branch could use feat/xxxx as its name.

happyandslow changed the base branch from main to feat/distributed-kv-cache on August 12, 2024 23:27
happyandslow force-pushed the lexu/vineyard-adptation branch from f888485 to 70f523c on August 14, 2024 18:48

happyandslow (Collaborator, Author):

Using the v1 configuration and benchmark code from (link) results in the following segmentation fault when running on v6d (commit link):

Fatal Python error: Segmentation fault

Thread 0x00007f5a497fe640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f5a49fff640 (most recent call first):
  File "/usr/lib/python3.10/selectors.py", line 416 in select
  File "/usr/lib/python3.10/subprocess.py", line 2021 in _communicate
  File "/usr/lib/python3.10/subprocess.py", line 1154 in communicate
  File "/home/ubuntu/.local/lib/python3.10/site-packages/cpuinfo/cpuinfo.py", line 2742 in get_cpu_info_json
  File "/home/ubuntu/.local/lib/python3.10/site-packages/cpuinfo/cpuinfo.py", line 2759 in get_cpu_info
  File "/home/ubuntu/lexu/vllm/vllm/usage/usage_lib.py", line 164 in _report_usage_once
  File "/home/ubuntu/lexu/vllm/vllm/usage/usage_lib.py", line 147 in _report_usage_worker
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f92673ff640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f9266bfe640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/home/ubuntu/.local/lib/python3.10/site-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Current thread 0x00007f9418a831c0 (most recent call first):
  File "/usr/local/lib/python3.10/dist-packages/vineyard-0.23.2-py3.10-linux-x86_64.egg/vineyard/llm/cache.py", line 312 in update
  File "/home/ubuntu/lexu/vllm/vllm/worker/vineyard_llm_cache.py", line 365 in update_seq_kv_caches
  File "/home/ubuntu/lexu/vllm/vllm/worker/vineyard_llm_cache.py", line 405 in update_kv_caches
  File "/home/ubuntu/lexu/vllm/vllm/worker/worker_base.py", line 280 in execute_model
  File "/home/ubuntu/lexu/vllm/vllm/executor/gpu_executor.py", line 110 in execute_model
  File "/home/ubuntu/lexu/vllm/vllm/engine/llm_engine.py", line 923 in step
  File "/home/ubuntu/lexu/vllm/vllm/entrypoints/llm.py", line 569 in _run_engine
  File "/home/ubuntu/lexu/vllm/vllm/entrypoints/llm.py", line 316 in generate
  File "/home/ubuntu/lexu/vllm/vllm/utils.py", line 880 in inner
  File "/home/ubuntu/lexu/vllm/benchmarks/benchmark_prefix_caching.py", line 55 in test_prefix
  File "/home/ubuntu/lexu/vllm/benchmarks/benchmark_prefix_caching.py", line 158 in main
  File "/home/ubuntu/lexu/vllm/benchmarks/benchmark_prefix_caching.py", line 208 in <module>

Extension modules: _brotli, simplejson._speedups, charset_normalizer.md, yaml._yaml, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, psutil._psutil_linux, psutil._psutil_posix, sentencepiece._sentencepiece, msgpack._cmsgpack, google.protobuf.pyext._message, setproctitle, uvloop.loop, ray._raylet, multidict._multidict, yarl._quoting_c, aiohttp._helpers, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket, frozenlist._frozenlist, regex._regex, lz4._version, lz4.frame._frame, PIL._imaging, zmq.backend.cython.context, zmq.backend.cython.message, zmq.backend.cython.socket, zmq.backend.cython._device, zmq.backend.cython._poll, zmq.backend.cython._proxy_steerable, zmq.backend.cython._version, zmq.backend.cython.error, zmq.backend.cython.utils, pyarrow.lib, scipy._lib._ccallback_c, scipy.sparse._sparsetools, scipy.sparse._csparsetools, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.strptime, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, 
pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.tslib, pandas._libs.lib, pandas._libs.hashing, pandas._libs.ops, numexpr.interpreter, bottleneck.reduce, bottleneck.nonreduce, bottleneck.nonreduce_axis, bottleneck.move, pandas._libs.arrays, pandas._libs.index, pandas._libs.join, pandas._libs.sparse, pyarrow._compute, pandas._libs.reduction, pandas._libs.indexing, pandas._libs.internals, pandas._libs.writers, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.parsers, pandas._libs.json, pandas._libs.testing (total: 110)
Segmentation fault (core dumped)

happyandslow (Collaborator, Author):

Further investigation using gdb (along with the prompt that triggers the error):

test prompt write again, we have 6 Entity Relationship Model:
1. Restaurant (restaurant\_id, name, location, contact\_details, branch\_id)
2. Queue (queue\_id, restaurant\_id, status, expected\_waiting\_time, period\_id)
3. Customer (customer\_id, queue\_id, name, contact\_details, time\_joined)
4. Table (table\_id, restaurant\_id, table\_number, seats, status)
5. Branch (branch\_id, name, location, contact\_details)
6. Period (period\_id, start\_time, end\_time)

Please write in English language.
Processed prompts:   0%|                                                                           | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]INFO 08-16 18:35:22 vineyard_llm_cache.py:135] prefetch_seq_kv_caches context_len 0 query_context_len 0 query_token_size 168

Thread 1 "pt_main_thread" received signal SIGSEGV, Segmentation fault.
0x00007ffff7ca3dee in _int_malloc (av=av@entry=0x7ffff7e19c80 <main_arena>, bytes=bytes@entry=8) at ./malloc/malloc.c:3903
3903    ./malloc/malloc.c: No such file or directory.
(gdb) backtrace
#0  0x00007ffff7ca3dee in _int_malloc (av=av@entry=0x7ffff7e19c80 <main_arena>, bytes=bytes@entry=8) at ./malloc/malloc.c:3903
#1  0x00007ffff7ca5139 in __GI___libc_malloc (bytes=8) at ./malloc/malloc.c:3329
#2  0x00007ffff68ae98c in operator new(unsigned long) () from /lib/x86_64-linux-gnu/libstdc++.so.6
#3  0x00007ffeee01ac7e in __gnu_cxx::new_allocator<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*>::allocate (
    this=0x7fffff7ff328, __n=<optimized out>) at /usr/include/c++/11/ext/new_allocator.h:103
#4  std::allocator_traits<std::allocator<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*> >::allocate (__n=<optimized out>, 
    __a=...) at /usr/include/c++/11/bits/alloc_traits.h:464
#5  std::_Vector_base<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*, std::allocator<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*> >::_M_allocate (__n=<optimized out>, this=<optimized out>) at /usr/include/c++/11/bits/stl_vector.h:346
#6  std::vector<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*, std::allocator<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*> >::_M_realloc_insert<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*> (
    this=this@entry=0x7fffff7ff328, __position=0x7ffff7e19ce0 <main_arena+96>) at /usr/include/c++/11/bits/vector.tcc:440
#7  0x00007ffeee01fd55 in std::vector<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*, std::allocator<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*> >::emplace_back<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*>
    (this=0x7fffff7ff328) at /usr/include/c++/11/bits/vector.tcc:121
#8  std::vector<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*, std::allocator<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >*> >::push_back (__x=@0x7fffff7ff170: 0x7fffff7ff4f0, this=0x7fffff7ff328) at /usr/include/c++/11/bits/stl_vector.h:1204
#9  nlohmann::detail::json_sax_dom_parser<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > > >::start_object (
    len=18446744073709551615, this=0x7fffff7ff320) at /home/ubuntu/lexu/v6d/thirdparty/nlohmann-json/single_include/nlohmann/json.hpp:5510
#10 nlohmann::detail::parser<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >, nlohmann::detail::iterator_input_adapter<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >::sax_parse_internal<nlohmann::detail::json_sax_dom_parser<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > > > > (this=0x7fffff7ff5e0, sax=sax@entry=0x7fffff7ff320)
    at /home/ubuntu/lexu/v6d/thirdparty/nlohmann-json/single_include/nlohmann/json.hpp:10416
#11 0x00007ffeee020a3d in nlohmann::detail::parser<nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >, nlohmann::detail::iterator_input_adapter<__gnu_cxx::__normal_iterator<char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >::parse (this=0x7fffff7ff5e0, 
    strict=<optimized out>, result=...) at /home/ubuntu/lexu/v6d/thirdparty/nlohmann-json/single_include/nlohmann/json.hpp:10344
#12 0x00007ffeee00bfbe in nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >::parse<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&, std::function<bool (int, nlohmann::detail::parse_event_t, nlohmann::basic_json<std::map, std::vector, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, long, unsigned long, double, std::allocator, nlohmann::adl_serializer, std::vector<unsigned char, std::allocator<unsigned char> > >&)>, bool, bool) (ignore_comments=false, allow_exceptions=true, cb=..., 
    i="{\"created\":{\"data_offset\":12437499072,\"data_size\":7680,\"is_gpu\":false,\"is_owner\":true,\"is_sealed\":false,\"map_size\":238540029960,\"object_id\":9223387945290489083,\"pointer\":139914922237120,\"store_fd\":18}"...) at /home/ubuntu/lexu/v6d/thirdparty/nlohmann-json/single_include/nlohmann/json.hpp:23287
#13 vineyard::ClientBase::doRead (this=this@entry=0x55556c951620, root=...) at /home/ubuntu/lexu/v6d/src/client/client_base.cc:573
#14 0x00007ffeedff2852 in vineyard::Client::CreateBuffer (this=0x55556c951620, size=7680, id=@0x7fffff7ff8c8: 18446744073709551615, payload=..., 
    buffer=std::shared_ptr<vineyard::MutableBuffer> (empty) = {...}) at /home/ubuntu/lexu/v6d/src/client/client.cc:622
#15 0x00007ffeedff2f4c in vineyard::Client::CreateBlob (this=0x55556c951620, size=7680, blob=std::unique_ptr<vineyard::BlobWriter> = {...})
    at /home/ubuntu/lexu/v6d/src/client/client.cc:235
#16 0x00007ffeea8c6aed in vineyard::KVTensorBuilder::KVTensorBuilder(vineyard::Client&, std::vector<long, std::allocator<long> > const&) ()
   from /usr/local/lib/libvineyard_llm_cache.so
#17 0x00007ffeea8c231d in vineyard::KVCacheBlockBuilder::KVCacheBlockBuilder(vineyard::Client&, int, int, int) () from /usr/local/lib/libvineyard_llm_cache.so
#18 0x00007ffeea8a83de in vineyard::KVCacheBuilder::Split(vineyard::KVCacheBlockBuilder*, std::vector<std::shared_ptr<vineyard::NodeData>, std::allocator<std::shared_ptr<vineyard::NodeData> > >, vineyard::KVCacheBlockBuilder*&) () from /usr/local/lib/libvineyard_llm_cache.so
#19 0x00007ffeea8aa0cf in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#20 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#21 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#22 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#23 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#24 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#25 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#26 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#27 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#28 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#29 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#30 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#31 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#32 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#33 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#34 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#35 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#36 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#37 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#38 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#39 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#40 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#41 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#42 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#43 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#44 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#45 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::allocator<std::pair<vineyard::LLMKV, vineyard::LLMKV> > > const&) () from /usr/local/lib/libvineyard_llm_cache.so
#46 0x00007ffeea8aa51b in vineyard::KVCacheBuilder::Update(std::vector<int, std::allocator<int> > const&, int, std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>, std::

@happyandslow force-pushed the lexu/vineyard-adptation branch 3 times, most recently from ac2a3ac to 5586db7 on August 26, 2024 18:17
# to ensure the tensor model parallel group is initialized.
self.vineyard_llm_cache = None

set_cpu_offload_max_bytes(

remove this duplicate snippet

if envs.VLLM_USE_VINEYARD_CACHE:
    if not self.scheduler_config.chunked_prefill_enabled:
        logger.warn("Vineyard LLM cache is not enabled, requires chunked prefill")
    elif not envs.VLLM_USE_FLASH_ATTN_DECODING:

If this is required, why do we make it an env variable?

model_config=self.model_config,
parallel_config=self.parallel_config,
kv_cache_dtype=self.kv_cache_dtype,
torch_dtype=get_kv_cache_torch_dtype(self.kv_cache_dtype,

torch_dtype can be determined from kv_cache_dtype and self.model_config. Would it be better to make it an internal variable?
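A minimal sketch of the reviewer's suggestion, assuming the cache wrapper receives kv_cache_dtype and the model dtype and derives torch_dtype itself instead of taking it as an argument. The names `VineyardCacheSketch` and `resolve_cache_dtype` are hypothetical, dtypes are plain strings so the sketch runs without torch, and the lookup table is an illustrative assumption, not vLLM's actual `get_kv_cache_torch_dtype`.

```python
from dataclasses import dataclass, field

# Illustrative table from kv_cache_dtype strings to storage dtype names.
# Dtypes are plain strings here (not torch.dtype objects) so the sketch
# runs without torch; fp8 variants are stored as raw bytes ("uint8").
_KV_CACHE_DTYPE_TABLE = {
    "half": "float16",
    "float16": "float16",
    "bfloat16": "bfloat16",
    "float": "float32",
    "float32": "float32",
    "fp8": "uint8",
    "fp8_e4m3": "uint8",
    "fp8_e5m2": "uint8",
}


def resolve_cache_dtype(kv_cache_dtype: str, model_dtype: str) -> str:
    """Derive the cache storage dtype from kv_cache_dtype and the model dtype."""
    if kv_cache_dtype == "auto":
        # "auto" reuses whatever dtype the model runs in
        return model_dtype
    if kv_cache_dtype not in _KV_CACHE_DTYPE_TABLE:
        raise ValueError(f"unsupported kv_cache_dtype: {kv_cache_dtype!r}")
    return _KV_CACHE_DTYPE_TABLE[kv_cache_dtype]


@dataclass
class VineyardCacheSketch:
    """Hypothetical wrapper: callers no longer pass torch_dtype; it is
    computed once from the two inputs it depends on."""

    kv_cache_dtype: str
    model_dtype: str
    torch_dtype: str = field(init=False)

    def __post_init__(self) -> None:
        self.torch_dtype = resolve_cache_dtype(self.kv_cache_dtype, self.model_dtype)


cache = VineyardCacheSketch(kv_cache_dtype="auto", model_dtype="bfloat16")
print(cache.torch_dtype)  # bfloat16
```

Keeping the derivation inside the constructor means callers cannot pass a torch_dtype that disagrees with kv_cache_dtype.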

)
if self.vineyard_llm_cache:
    logger.info("Using Vineyard LLM cache")
else:

In this case, it only logs a warning. What if I enable the cache but disable chunked prefill? Should we raise an error here and stop instead of continuing to the next steps? In that case, you would not need so many self.vineyard_llm_cache checks.
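A fail-fast version of this check, as the reviewer proposes, could look like the sketch below. The dataclass and function names are hypothetical, and treating flash-attention decoding as a hard requirement is an assumption read from the `elif not envs.VLLM_USE_FLASH_ATTN_DECODING` branch and the earlier review question.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VineyardCacheFlags:
    """Hypothetical bundle of the three flags checked in the PR
    (envs.VLLM_USE_VINEYARD_CACHE, scheduler_config.chunked_prefill_enabled,
    envs.VLLM_USE_FLASH_ATTN_DECODING)."""

    use_vineyard_cache: bool
    chunked_prefill_enabled: bool
    use_flash_attn_decoding: bool


def should_build_vineyard_cache(flags: VineyardCacheFlags) -> bool:
    """Return True if the cache should be created; raise if it was
    requested but a prerequisite is missing, instead of only warning."""
    if not flags.use_vineyard_cache:
        return False  # feature not requested: nothing to validate
    if not flags.chunked_prefill_enabled:
        raise ValueError("Vineyard LLM cache requires chunked prefill")
    if not flags.use_flash_attn_decoding:
        raise ValueError("Vineyard LLM cache requires flash-attention decoding")
    return True


# Enabled with all prerequisites met: the cache is built.
print(should_build_vineyard_cache(VineyardCacheFlags(True, True, True)))  # True
```

With a misconfiguration rejected at startup, the only question left for downstream code is whether the feature was requested at all, rather than whether an earlier warning silently disabled it.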

intermediate_tensors = None
if not get_pp_group().is_first_rank:
    intermediate_tensors = self.model.make_empty_intermediate_tensors(
        batch_size=batch_size,
        dtype=self.model_config.dtype,
        device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
if self.vineyard_llm_cache and kv_caches[0] is not None:

It seems this method is profile_run. Do we need to update the KV cache here?
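If the answer is no (profile_run feeds dummy sequences only to measure peak memory, so their KV entries are not worth storing), one option is to state that intent explicitly rather than relying on the `kv_caches[0] is not None` check. The sketch below uses hypothetical names (`RunnerSketch`, `is_profile_run`, `cache_update`) and a plain callback in place of the real Vineyard update call.

```python
from typing import Callable, List, Optional


class RunnerSketch:
    """Hypothetical model runner: the memory-profiling pass never writes
    to the external KV store, and says so explicitly."""

    def __init__(self, cache_update: Optional[Callable[[List[int]], None]] = None) -> None:
        # cache_update stands in for the Vineyard cache's update hook (assumption)
        self.cache_update = cache_update

    def execute_model(self, token_ids: List[int], *, is_profile_run: bool = False) -> None:
        # ... the forward pass would run here ...
        if self.cache_update is not None and not is_profile_run:
            self.cache_update(token_ids)

    def profile_run(self, num_tokens: int = 8) -> None:
        # Dummy tokens only measure peak memory; their KV entries are not reusable.
        self.execute_model([0] * num_tokens, is_profile_run=True)


updates: List[List[int]] = []
runner = RunnerSketch(cache_update=updates.append)
runner.profile_run()
runner.execute_model([101, 102, 103])
print(updates)  # [[101, 102, 103]]
```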

torch_dtype=torch_dtype,
)

def prefetch_seq_kv_caches(

I will double-check the logic here later.

Based on another version of vllm: sighingnow@d347dab

Cherry-pick from commit d347dab

Signed-off-by: Tao He <[email protected]>
(cherry picked from commit 1545f6bf7edcd667e305d3fbcadd913066f04747)

resolving vllm update diff

temporarily comment out torch.distributed for single node env

add VineyardCacheConfig with https://github.com/v6d-io/v6d/blob/ebe8f077e3d3780a27d49238c501854b6b8e29df/modules/llm-cache/ds/kv_cache_block.cc#L163 commented out; cache_ops fix

remove CacheConfig from argument (configure through ENV)

v6d: fix integration w/ v1 APIs

Signed-off-by: Haiyang Shi <[email protected]>

Change model_runner to latest version

cherry pick model_runner from d347dab source sighingnow@d347dab

fix reshape_and_cache_flash argument

add cache prefetch/update to work_base

clean up

Fix after rebase to 029c71d

remove tensor copy from cache managed address to pin memory

clean up
@happyandslow force-pushed the lexu/vineyard-adptation branch from 3fd0048 to 45554d1 on September 16, 2024 16:40
@Jeffwan left a comment

let's merge this one

@Jeffwan merged commit 7cdac48 into feat/distributed-kv-cache on Oct 1, 2024