
Commit d05c4cd

Commit message: "working"
1 parent d92c077 commit d05c4cd

File tree

  llumnix/__init__.py
  llumnix/backends/bladellm/migration_backend.py
  llumnix/backends/bladellm/proto/migration_worker_pb2_grpc.py
  llumnix/backends/bladellm/worker.py
  tests/unit_test/backends/bladellm/proto/mock_migration_worker_pb2_grpc.py
  tests/unit_test/backends/bladellm/test_migration_backend.py

6 files changed: +43 -37 lines changed

llumnix/__init__.py (+3 -3)

@@ -11,8 +11,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import vllm
-from vllm import *
+# import vllm
+# from vllm import *
 
 from llumnix.server_info import ServerInfo
 from llumnix.entrypoints.utils import (launch_ray_cluster,
@@ -39,4 +39,4 @@
     "QueueType",
 ]
 
-__all__.extend(getattr(vllm, "__all__", []))
+# __all__.extend(getattr(vllm, "__all__", []))
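Note (editor's context, not commit text): the removed lines made llumnix re-export vllm's entire public API, so callers could import vllm names from llumnix directly; commenting them out drops that implicit dependency. A minimal, runnable sketch of the re-export pattern, using the stdlib os module as a stand-in for vllm:

import os
from os import *  # pull os's public names into this module's namespace

__all__ = ["ServerInfo"]  # this package's own exports (name borrowed from llumnix)
# Also re-export everything os declares public; getattr falls back to an
# empty list for modules that define no __all__.
__all__.extend(getattr(os, "__all__", []))
print(len(__all__))  # own exports plus every name os exports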

llumnix/backends/bladellm/migration_backend.py (+4 -5)

@@ -36,11 +36,10 @@ def __init__(self, worker_address: str, migration_config: MigrationConfig, state
         self.worker_address = worker_address
         self.state_manager = state_manager
         self.num_migration_cache_blocks = migration_config.migration_cache_blocks
-
-        migration_cache_key_shape = state_manager._kv_cache[0].shape
-        migration_cache_key_shape[0] = migration_config.migration_cache_blocks
-        migration_cache_value_shape = state_manager._kv_cache[1].shape
-        migration_cache_value_shape[0] = migration_config.migration_cache_blocks
+        migration_cache_key_shape = list([len(state_manager._kv_cache[0])]) + list(state_manager._kv_cache[0][0].shape)
+        migration_cache_key_shape[1] = migration_config.migration_cache_blocks
+        migration_cache_value_shape = list([len(state_manager._kv_cache[1])]) + list(state_manager._kv_cache[1][0].shape)
+        migration_cache_value_shape[1] = migration_config.migration_cache_blocks
         if state_manager.dtype in NUMPY_SUPPORTED_DTYPES:
             self.migration_cache_dtype = state_manager.dtype
         else:
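A note on the reshaped buffer computation (my reading of the diff, not commit text): the old code treated state_manager._kv_cache[0] as one stacked tensor and assigned into its .shape, which fails because torch.Size is an immutable tuple; the kv cache here is evidently a list of per-layer tensors, so the new code builds a plain list shape [num_layers, num_blocks, ...] and clamps the block dimension (now index 1) to migration_cache_blocks. A runnable sketch with dummy tensors (all sizes are made up for illustration):

import torch

num_layers, num_gpu_blocks, migration_cache_blocks = 4, 16, 8
# Per-layer key cache: each tensor is [num_blocks, block_size, kv_heads, head_dim]
key_cache = [torch.zeros(num_gpu_blocks, 3, 2, 8) for _ in range(num_layers)]

# torch.Size is a tuple, so key_cache[0].shape[0] = ... would raise TypeError;
# build a mutable list shape instead and shrink the block dimension.
migration_cache_key_shape = [len(key_cache)] + list(key_cache[0].shape)
migration_cache_key_shape[1] = migration_cache_blocks
print(migration_cache_key_shape)  # [4, 8, 3, 2, 8]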

llumnix/backends/bladellm/proto/migration_worker_pb2_grpc.py (+1 -1)

@@ -3,7 +3,7 @@
 import grpc
 
 from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
-import migration_worker_pb2 as migration__worker__pb2
+import llumnix.backends.bladellm.proto.migration_worker_pb2 as migration__worker__pb2
 
 
 class MigrationWorkerStub(object):
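Background on this one-liner (a well-known protoc quirk, not stated in the commit): grpc_tools.protoc emits a bare top-level import for the sibling *_pb2 module, which only resolves when the proto directory itself is on sys.path; qualifying it with the full package path makes the generated stub importable from the repo root. A quick check of which form resolves (assumes the llumnix package is on PYTHONPATH):

import importlib

# The bare name fails unless .../bladellm/proto is itself on sys.path:
try:
    importlib.import_module("migration_worker_pb2")
except ModuleNotFoundError as exc:
    print("bare import fails:", exc)

# The package-qualified name resolves from anywhere in the repo:
mod = importlib.import_module("llumnix.backends.bladellm.proto.migration_worker_pb2")
print("qualified import ok:", mod.__name__)

The same one-line fix is applied to the mock stub under tests/ further below.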

llumnix/backends/bladellm/worker.py (+3 -4)

@@ -30,12 +30,11 @@
 
 logger = init_logger(__name__)
 
-class MigrationWorker(migration_worker_pb2_grpc.LlumnixWorkerServicer, RemoteWorker):
+class MigrationWorker(migration_worker_pb2_grpc.MigrationWorkerServicer, RemoteWorker):
     def __init__(self, instance_id: str, worker_addr: str, migration_config: MigrationConfig,
                  *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-
-        torch.cuda.set_device(args.device)
+        torch.cuda.set_device(args[1].device)
         self.instance_id = instance_id
         self.migration_backend = get_migration_backend(worker_addr, migration_config, self._engine._state_manager)
 
@@ -89,7 +88,7 @@ async def worker_server(rank: int, args: ServingArgs, instance_id: str, migratio
         else f"unix://{args.worker_socket_path}.{instance_id}.{rank}"
     )
 
-    worker = MigrationWorker(rank, args, instance_id, listen_addr, migration_config)
+    worker = MigrationWorker(instance_id, listen_addr, migration_config, rank, args)
 
     server = grpc.aio.server(migration_thread_pool=ThreadPoolExecutor(max_workers=1))
     bladellm_pb2_grpc.add_WorkerServicer_to_server(worker, server)
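Why args[1].device works here (my reading): the constructor now takes the Llumnix-specific parameters first and forwards the remainder to RemoteWorker through *args, and the updated call site passes (rank, args) as those trailing positionals, so the ServingArgs object lands at index 1 of the captured tuple. A self-contained sketch of that capture, with stand-in classes rather than the real bladellm types:

class ServingArgs:
    def __init__(self, device: str):
        self.device = device

class Base:
    def __init__(self, rank: int, serving_args: "ServingArgs"):
        self.rank = rank

class Worker(Base):
    def __init__(self, instance_id: str, addr: str, cfg: dict, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # *args captured (rank, serving_args), so the device lives at args[1]
        self.device = args[1].device

w = Worker("i-0", "localhost:1234", {}, 0, ServingArgs(device="cuda:0"))
print(w.device)  # cuda:0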

tests/unit_test/backends/bladellm/proto/mock_migration_worker_pb2_grpc.py (+1 -1)

@@ -3,7 +3,7 @@
 import grpc
 
 from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
-import mock_migration_worker_pb2 as mock__migration__worker__pb2
+import tests.unit_test.backends.bladellm.proto.mock_migration_worker_pb2 as mock__migration__worker__pb2
 
 
 class MockMigrationWorkerStub(object):

tests/unit_test/backends/bladellm/test_migration_backend.py (+31 -23)

@@ -11,13 +11,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
 from typing import List
 import random
 import asyncio
 import torch
 import pytest
 import grpc
-from multiprocessing import Process
+from multiprocessing import Process, set_start_method
 from concurrent.futures import ThreadPoolExecutor
 from google.protobuf import empty_pb2
 import numpy as np
@@ -30,10 +31,12 @@
 from llumnix.backends.bladellm.proto import migration_worker_pb2_grpc, migration_worker_pb2
 from llumnix.internal_config import MigrationConfig
 from llumnix.utils import random_uuid
+from llumnix.arg_utils import EngineManagerArgs
 
-from .proto import mock_migration_worker_pb2_grpc, mock_migration_worker_pb2
+from tests.unit_test.backends.bladellm.proto import mock_migration_worker_pb2_grpc, mock_migration_worker_pb2
 
-class MockMigrationWorker(mock_migration_worker_pb2_grpc.MockMigrationWorkerServicer, MigrationWorker):
+# class MockMigrationWorker(mock_migration_worker_pb2_grpc.MockMigrationWorkerServicer, MigrationWorker):
+class MockMigrationWorker(MigrationWorker):
     def get_kv_cache_meta(self, request, context):
         return mock_migration_worker_pb2.KvCacheMeta(
             shape=self.migration_backend.dummy_key_cache.shape,
@@ -56,63 +59,68 @@ def get_gpu_cache(self, request, context):
         torch.cuda.synchronize()
         return mock_migration_worker_pb2.KvCacheData(key=key, value=value)
 
-async def worker_main(rank: int, args: ServingArgs, instance_id: str, migration_config: MigrationConfig):
-    asyncio.run(launch_worker(rank, args, instance_id, migration_config))
-
-async def launch_worker(rank: int, args: ServingArgs, instance_id: str, migration_config: MigrationConfig):
-    listen_addr = f"unix://{args.worker_socket_path}.{instance_id}.{rank}"
-    worker = MockMigrationWorker(rank, args, instance_id, listen_addr, migration_config)
+def worker_main(listen_addr: str, rank: int, args: ServingArgs, instance_id: str, migration_config: MigrationConfig):
+    asyncio.run(launch_worker(listen_addr, rank, args, instance_id, migration_config))
 
+async def launch_worker(listen_addr: str, rank: int, args: ServingArgs, instance_id: str, migration_config: MigrationConfig):
+    worker = MockMigrationWorker(instance_id, listen_addr, migration_config, rank, args)
     server = grpc.aio.server(migration_thread_pool=ThreadPoolExecutor(max_workers=1))
     bladellm_pb2_grpc.add_WorkerServicer_to_server(worker, server)
     migration_worker_pb2_grpc.add_MigrationWorkerServicer_to_server(worker, server)
-    mock_migration_worker_pb2_grpc.add_MockMigrationWorkerServicer_to_server(worker, server)
+    # mock_migration_worker_pb2_grpc.add_MockMigrationWorkerServicer_to_server(worker, server)
     server.add_insecure_port(listen_addr)
+    print(f"Starting server on {listen_addr}")
    await server.start()
+    print(f"Server running on {listen_addr}")
    await server.wait_for_termination()
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU to run the test.")
 @pytest.mark.parametrize("backend", ['grpc'])
 def test_migrate_cache(backend):
     worker_count = 2
     worker_args = ServingArgs(
-        max_gpu_memory_utilization=0.1,
+        max_gpu_memory_utilization=0.5,
         block_size=3,
         load_model_options=LoadModelOptions(
-            model='/mnt/self-hosted/model/Qwen2.5-7B', attn_cls="paged", disable_cuda_graph=True
+            model='facebook/opt-125m', attn_cls="paged", disable_cuda_graph=True
         )
     )
-    worker_socket_addr = []
-    migration_config = MigrationConfig(migration_backend=backend, migration_cache_blocks=8)
+
+    worker_socket_addrs = []
+    migration_config = EngineManagerArgs(migration_backend=backend, migration_cache_blocks=8).create_migration_config()
 
+    set_start_method("spawn", force=True)
     backends: List[Process] = []
     for i in range(worker_count):
         instance_id = random_uuid()
-        p = Process(target=worker_main, args=(0, worker_args, instance_id, migration_config))
-        worker_socket_addr.append(f"unix://{worker_args.worker_socket_path}.{instance_id}.{0}")
+        worker_args.device=f"cuda:{i}"
+        worker_socket_addrs.append(f"localhost:{1234+i}")
+
+        p = Process(target=worker_main, args=(worker_socket_addrs[-1], i, worker_args, instance_id, migration_config))
         p.start()
         backends.append(p)
 
     # assert all(ray.get([worker0.execute_method.remote('rebuild_migration_backend',
     #                     instance_rank=instance_rank, group_name=group_name),
     #                     worker1.execute_method.remote('rebuild_migration_backend',
     #                     instance_rank=instance_rank, group_name=group_name)]))
+    time.sleep(5)
 
     for i in range(worker_count):
-        with grpc.insecure_channel(worker_socket_addr[i]) as channel:
+        with grpc.insecure_channel(worker_socket_addrs[i]) as channel:
             stub = migration_worker_pb2_grpc.MigrationWorkerStub(channel)
             responce = stub.warmup(empty_pb2.Empty())
             assert responce.ok
 
-    with grpc.insecure_channel(worker_socket_addr[0]) as channel:
+    with grpc.insecure_channel(worker_socket_addrs[0]) as channel:
         stub = mock_migration_worker_pb2_grpc.MockMigrationWorkerStub(channel)
         responce = stub.get_kv_cache_meta(empty_pb2.Empty())
         kv_cache_shape, dtype, total_gpu_blocks = responce.shape, responce.dtype, responce.num_gpu_blocks
-
+
     dummy_key_data = torch.randn(size=kv_cache_shape, dtype=dtype)
     dummy_value_data = torch.randn(size=kv_cache_shape, dtype=dtype)
 
-    with grpc.insecure_channel(worker_socket_addr[0]) as channel:
+    with grpc.insecure_channel(worker_socket_addrs[0]) as channel:
         stub = mock_migration_worker_pb2_grpc.MockMigrationWorkerStub(channel)
         responce = stub.set_gpu_cache(mock_migration_worker_pb2.KvCacheData(
             key=dummy_key_data.numpy().tobytes(),
@@ -126,15 +134,15 @@ def test_migrate_cache(backend):
     dst_blocks = list(range(total_gpu_blocks))
     random.shuffle(dst_blocks)
 
-    with grpc.insecure_channel(worker_socket_addr[1]) as channel:
+    with grpc.insecure_channel(worker_socket_addrs[1]) as channel:
         src_stub = migration_worker_pb2_grpc.MigrationWorkerStub(channel)
         src_stub.migrate_cache(migration_worker_pb2.MigrateRequest(
-            src_handlers=[None, worker_socket_addr[0]],
+            src_handlers=[None, worker_socket_addrs[0]],
             src_blocks=list(range(total_gpu_blocks)),
             dst_blocks=dst_blocks,
         ))
 
-    with grpc.insecure_channel(worker_socket_addr[1]) as channel:
+    with grpc.insecure_channel(worker_socket_addrs[1]) as channel:
         stub = mock_migration_worker_pb2_grpc.MockMigrationWorkerStub(channel)
         responce = stub.get_gpu_cache(empty_pb2.Empty())
         worker1_key_data = torch.from_numpy(np.frombuffer(responce.key, dtype=dtype))
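Two test-infrastructure details worth calling out (my interpretation, not commit text): each worker process initializes its own CUDA device, and a fork()ed child cannot safely reuse the parent's CUDA context, hence set_start_method("spawn", force=True) before launching the workers; the time.sleep(5) is a crude readiness wait for the gRPC servers. A minimal sketch of the spawn pattern:

import time
from multiprocessing import Process, set_start_method

def serve(port: int) -> None:
    # In the real test this starts a gRPC server that owns one CUDA device;
    # here we just idle to show the process startup pattern.
    print(f"worker listening on localhost:{port}")
    time.sleep(1)

if __name__ == "__main__":
    # "spawn" starts a fresh interpreter per child; force=True overrides any
    # start method set earlier in the same process.
    set_start_method("spawn", force=True)
    procs = [Process(target=serve, args=(1234 + i,)) for i in range(2)]
    for p in procs:
        p.start()
    time.sleep(5)  # crude readiness wait, mirroring the test
    for p in procs:
        p.join()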
