11
11
# See the License for the specific language governing permissions and
12
12
# limitations under the License.
13
13
14
+ import time
14
15
from typing import List
15
16
import random
16
17
import asyncio
17
18
import torch
18
19
import pytest
19
20
import grpc
20
- from multiprocessing import Process
21
+ from multiprocessing import Process , set_start_method
21
22
from concurrent .futures import ThreadPoolExecutor
22
23
from google .protobuf import empty_pb2
23
24
import numpy as np
30
31
from llumnix .backends .bladellm .proto import migration_worker_pb2_grpc , migration_worker_pb2
31
32
from llumnix .internal_config import MigrationConfig
32
33
from llumnix .utils import random_uuid
34
+ from llumnix .arg_utils import EngineManagerArgs
33
35
34
- from .proto import mock_migration_worker_pb2_grpc , mock_migration_worker_pb2
36
+ from tests . unit_test . backends . bladellm .proto import mock_migration_worker_pb2_grpc , mock_migration_worker_pb2
35
37
36
- class MockMigrationWorker (mock_migration_worker_pb2_grpc .MockMigrationWorkerServicer , MigrationWorker ):
38
+ # class MockMigrationWorker(mock_migration_worker_pb2_grpc.MockMigrationWorkerServicer, MigrationWorker):
39
+ class MockMigrationWorker (MigrationWorker ):
37
40
def get_kv_cache_meta (self , request , context ):
38
41
return mock_migration_worker_pb2 .KvCacheMeta (
39
42
shape = self .migration_backend .dummy_key_cache .shape ,
@@ -56,63 +59,68 @@ def get_gpu_cache(self, request, context):
56
59
torch .cuda .synchronize ()
57
60
return mock_migration_worker_pb2 .KvCacheData (key = key , value = value )
58
61
59
def worker_main(listen_addr: str, rank: int, args: ServingArgs, instance_id: str,
                migration_config: MigrationConfig) -> None:
    """Process entry point: run ``launch_worker`` to completion on a new event loop.

    This must be a plain synchronous function. It is handed to
    ``multiprocessing.Process(target=...)``, which simply calls its target; if
    this were declared ``async def``, that call would only create a coroutine
    object and the server would never be started.

    Args:
        listen_addr: Address the worker's gRPC server binds to.
        rank: Rank forwarded to the worker.
        args: Serving arguments forwarded to the worker.
        instance_id: Identifier of the instance this worker belongs to.
        migration_config: Migration settings forwarded to the worker.
    """
    asyncio.run(launch_worker(listen_addr, rank, args, instance_id, migration_config))
65
64
65
async def launch_worker(listen_addr: str, rank: int, args: ServingArgs, instance_id: str,
                        migration_config: MigrationConfig) -> None:
    """Serve a ``MockMigrationWorker`` over gRPC on ``listen_addr`` until termination.

    Builds the worker, registers it with an asyncio gRPC server as both the
    BladeLLM worker servicer and the migration worker servicer, binds an
    insecure port, starts the server and then waits for it to terminate.

    Args:
        listen_addr: Address passed to the worker and bound by the server.
        rank: Rank passed to the worker.
        args: Serving arguments passed to the worker.
        instance_id: Identifier of the instance this worker belongs to.
        migration_config: Migration settings passed to the worker.
    """
    worker = MockMigrationWorker(instance_id, listen_addr, migration_config, rank, args)

    # A single-thread pool is handed to the server for non-asyncio handlers.
    server = grpc.aio.server(migration_thread_pool=ThreadPoolExecutor(max_workers=1))
    bladellm_pb2_grpc.add_WorkerServicer_to_server(worker, server)
    migration_worker_pb2_grpc.add_MigrationWorkerServicer_to_server(worker, server)
    # NOTE(review): the mock servicer is not registered here, but test_migrate_cache
    # still creates MockMigrationWorkerStub clients against this server
    # (get_kv_cache_meta / set_gpu_cache / get_gpu_cache). Presumably those RPCs are
    # rejected as unimplemented unless the service is registered elsewhere -- confirm
    # before relying on this test.
    # mock_migration_worker_pb2_grpc.add_MockMigrationWorkerServicer_to_server(worker, server)
    server.add_insecure_port(listen_addr)
    print(f"Starting server on {listen_addr} ")
    await server.start()
    print(f"Server running on {listen_addr} ")
    await server.wait_for_termination()
73
76
74
77
@pytest .mark .skipif (torch .cuda .device_count () < 2 , reason = "Need at least 2 GPU to run the test." )
75
78
@pytest .mark .parametrize ("backend" , ['grpc' ])
76
79
def test_migrate_cache (backend ):
77
80
worker_count = 2
78
81
worker_args = ServingArgs (
79
- max_gpu_memory_utilization = 0.1 ,
82
+ max_gpu_memory_utilization = 0.5 ,
80
83
block_size = 3 ,
81
84
load_model_options = LoadModelOptions (
82
- model = '/mnt/self-hosted/model/Qwen2.5-7B ' , attn_cls = "paged" , disable_cuda_graph = True
85
+ model = 'facebook/opt-125m ' , attn_cls = "paged" , disable_cuda_graph = True
83
86
)
84
87
)
85
- worker_socket_addr = []
86
- migration_config = MigrationConfig (migration_backend = backend , migration_cache_blocks = 8 )
88
+
89
+ worker_socket_addrs = []
90
+ migration_config = EngineManagerArgs (migration_backend = backend , migration_cache_blocks = 8 ).create_migration_config ()
87
91
92
+ set_start_method ("spawn" , force = True )
88
93
backends : List [Process ] = []
89
94
for i in range (worker_count ):
90
95
instance_id = random_uuid ()
91
- p = Process (target = worker_main , args = (0 , worker_args , instance_id , migration_config ))
92
- worker_socket_addr .append (f"unix://{ worker_args .worker_socket_path } .{ instance_id } .{ 0 } " )
96
+ worker_args .device = f"cuda:{ i } "
97
+ worker_socket_addrs .append (f"localhost:{ 1234 + i } " )
98
+
99
+ p = Process (target = worker_main , args = (worker_socket_addrs [- 1 ], i , worker_args , instance_id , migration_config ))
93
100
p .start ()
94
101
backends .append (p )
95
102
96
103
# assert all(ray.get([worker0.execute_method.remote('rebuild_migration_backend',
97
104
# instance_rank=instance_rank, group_name=group_name),
98
105
# worker1.execute_method.remote('rebuild_migration_backend',
99
106
# instance_rank=instance_rank, group_name=group_name)]))
107
+ time .sleep (5 )
100
108
101
109
for i in range (worker_count ):
102
- with grpc .insecure_channel (worker_socket_addr [i ]) as channel :
110
+ with grpc .insecure_channel (worker_socket_addrs [i ]) as channel :
103
111
stub = migration_worker_pb2_grpc .MigrationWorkerStub (channel )
104
112
responce = stub .warmup (empty_pb2 .Empty ())
105
113
assert responce .ok
106
114
107
- with grpc .insecure_channel (worker_socket_addr [0 ]) as channel :
115
+ with grpc .insecure_channel (worker_socket_addrs [0 ]) as channel :
108
116
stub = mock_migration_worker_pb2_grpc .MockMigrationWorkerStub (channel )
109
117
responce = stub .get_kv_cache_meta (empty_pb2 .Empty ())
110
118
kv_cache_shape , dtype , total_gpu_blocks = responce .shape , responce .dtype , responce .num_gpu_blocks
111
-
119
+
112
120
dummy_key_data = torch .randn (size = kv_cache_shape , dtype = dtype )
113
121
dummy_value_data = torch .randn (size = kv_cache_shape , dtype = dtype )
114
122
115
- with grpc .insecure_channel (worker_socket_addr [0 ]) as channel :
123
+ with grpc .insecure_channel (worker_socket_addrs [0 ]) as channel :
116
124
stub = mock_migration_worker_pb2_grpc .MockMigrationWorkerStub (channel )
117
125
responce = stub .set_gpu_cache (mock_migration_worker_pb2 .KvCacheData (
118
126
key = dummy_key_data .numpy ().tobytes (),
@@ -126,15 +134,15 @@ def test_migrate_cache(backend):
126
134
dst_blocks = list (range (total_gpu_blocks ))
127
135
random .shuffle (dst_blocks )
128
136
129
- with grpc .insecure_channel (worker_socket_addr [1 ]) as channel :
137
+ with grpc .insecure_channel (worker_socket_addrs [1 ]) as channel :
130
138
src_stub = migration_worker_pb2_grpc .MigrationWorkerStub (channel )
131
139
src_stub .migrate_cache (migration_worker_pb2 .MigrateRequest (
132
- src_handlers = [None , worker_socket_addr [0 ]],
140
+ src_handlers = [None , worker_socket_addrs [0 ]],
133
141
src_blocks = list (range (total_gpu_blocks )),
134
142
dst_blocks = dst_blocks ,
135
143
))
136
144
137
- with grpc .insecure_channel (worker_socket_addr [1 ]) as channel :
145
+ with grpc .insecure_channel (worker_socket_addrs [1 ]) as channel :
138
146
stub = mock_migration_worker_pb2_grpc .MockMigrationWorkerStub (channel )
139
147
responce = stub .get_gpu_cache (empty_pb2 .Empty ())
140
148
worker1_key_data = torch .from_numpy (np .frombuffer (responce .key , dtype = dtype ))
0 commit comments