@@ -87,27 +87,68 @@ def _get_from_waiting_queue(
         self,
         waiting_queue: deque[RequestQueueItem],
         max_req_count: int,
+        enable_attention_dp: bool,
+        all_ranks_num_active_requests: Optional[List[int]] = None,
     ) -> List[RequestQueueItem]:
-        """Safely extracts up to max_req_count items from a deque.
-
+        """Extract up to max_req_count schedulable items from the waiting queue.
         Args:
             waiting_queue: The queue to pop items from.
             max_req_count: Maximum items to retrieve. Returns empty list if <=0.
-
+            enable_attention_dp: Whether to enable attention DP scheduling.
+            all_ranks_num_active_requests: Number of active requests for each rank.
         Returns:
-            List of retrieved items (may be shorter than max_req_count if queue empties first).
+            List of requests that can be processed.
         """
-        # Edge case handling
-        if max_req_count <= 0:  # Handles negative/zero counts
+
+        if max_req_count <= 0:
             return []

-        items = []
         req_count = 0
+        items = []
+        pending_requests = []
+
+        # Copy the per-rank active counts so strict-rank requests can reserve capacity
+        scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
+        ) if enable_attention_dp else None
         while req_count < max_req_count and waiting_queue:
-            items.append(waiting_queue.popleft())
-            req_count += 1
+            req_item = waiting_queue.popleft()
+            can_process = self._can_process_attention_dp_request(
+                req_item, scheduling_all_ranks_num_active_requests
+            ) if enable_attention_dp else True
+
+            if can_process:
+                items.append(req_item)
+                req_count += 1
+            else:
+                pending_requests.append(req_item)
+
+        # Put the pending requests back to the waiting queue
+        # All ranks should have the same waiting queue
+        waiting_queue.extendleft(reversed(pending_requests))
+
         return items

+    def _can_process_attention_dp_request(
+            self, req_item: RequestQueueItem,
+            all_ranks_num_active_requests: List[int]) -> bool:
+        """Return True if the request can be processed immediately, else False."""
+
+        scheduling_params = getattr(req_item.request, 'py_scheduling_params',
+                                    None)
+        if scheduling_params is None:
+            return True
+
+        target_dp_rank = scheduling_params.attention_dp_rank
+        if target_dp_rank is None or scheduling_params.attention_dp_relax:
+            return True
+
+        if all_ranks_num_active_requests[
+                target_dp_rank] < self.max_num_active_requests:
+            all_ranks_num_active_requests[target_dp_rank] += 1
+            return True
+
+        return False
+
     def enqueue_requests(self, requests: List[ExecutorRequest]):
         req_ids = []
         try:
@@ -166,8 +207,12 @@ def can_enqueue_request(self) -> bool:
         return can_enqueue and self.dist.rank == 0

     def _fetch_and_process_requests(
-            self, total_num_active_requests: int,
-            total_max_num_active_requests: int) -> List[RequestQueueItem]:
+        self,
+        total_num_active_requests: int,
+        total_max_num_active_requests: int,
+        enable_attention_dp: bool,
+        all_ranks_num_active_requests: Optional[List[int]] = None
+    ) -> List[RequestQueueItem]:
         """Common logic for fetching and processing requests from the queue."""
         # Calculate timeout
         timeout = None if (total_num_active_requests == 0) and len(
@@ -195,7 +240,8 @@ def _fetch_and_process_requests(

         new_requests = self._get_from_waiting_queue(
             self.waiting_queue,
-            total_max_num_active_requests - total_num_active_requests)
+            total_max_num_active_requests - total_num_active_requests,
+            enable_attention_dp, all_ranks_num_active_requests)

         # Update performance metrics
         if self.enable_iter_perf_stats and self.dist.rank == 0:
@@ -218,9 +264,11 @@ def _fetch_new_requests_attention_tp(
         total_num_active_requests = num_active_requests
         total_max_num_active_requests = self.max_num_active_requests

-        # Use common request fetching logic
+        # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=False)

         # Merge requests and add to active list
         merged_requests = self._merge_requests(new_requests)
@@ -238,34 +286,84 @@ def _fetch_new_requests_attention_dp(
         total_num_active_requests = sum(all_ranks_num_active_requests)
         total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests

-        # Use common request fetching logic
+        # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
-            total_num_active_requests, total_max_num_active_requests)
+            total_num_active_requests,
+            total_max_num_active_requests,
+            enable_attention_dp=True,
+            all_ranks_num_active_requests=all_ranks_num_active_requests)

-        # Balance requests across ranks
-        num_new_requests_all_ranks = len(new_requests)
-        self.expected_num_active_requests = max(
-            (total_num_active_requests + num_new_requests_all_ranks +
-             self.dist.tp_size - 1) // self.dist.tp_size,
-            max(all_ranks_num_active_requests),
-        )
-
-        new_requests_cur_rank = self._balance_requests_across_ranks(
+        # Schedule attention dp requests
+        all_ranks_new_requests = self._schedule_attention_dp_requests(
             new_requests, all_ranks_num_active_requests)
+        new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]

         # Update performance metrics
         if self.enable_iter_perf_stats and self.start_times:
             self._update_new_active_requests_queue_latency(
                 new_requests_cur_rank)

         # Update counters
-        self.num_fetch_requests += num_new_requests_all_ranks
+        self.num_fetch_requests += len(new_requests)
         self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)

         # Merge requests and add to active list
         new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
         return new_requests_cur_rank

+    def _schedule_attention_dp_requests(
+            self, new_requests: List[RequestQueueItem],
+            all_ranks_num_active_requests: List[int]) -> Dict[int, List[RequestQueueItem]]:
+        """Schedule attention DP requests and return a mapping from TP rank to its new requests."""
+
+        # Map from ranks to new requests
+        all_ranks_new_requests = {
+            tp_rank: []
+            for tp_rank in range(self.dist.tp_size)
+        }
+
+        # Prioritize the requests that are not in relax mode
+        def get_relax_value(req_item):
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is None:
+                return True
+            return scheduling_params.attention_dp_relax
+
+        new_requests = sorted(new_requests, key=get_relax_value, reverse=True)
+
+        # Try to place each request on its target DP rank while that rank is below max_num_active_requests
+        remaining_unscheduled = []
+        for req_item in new_requests:
+            scheduled = False
+            scheduling_params = getattr(req_item.request,
+                                        'py_scheduling_params', None)
+            if scheduling_params is not None:
+                target_dp_rank = scheduling_params.attention_dp_rank
+                if target_dp_rank is not None and all_ranks_num_active_requests[
+                        target_dp_rank] < self.max_num_active_requests:
+                    all_ranks_num_active_requests[target_dp_rank] += 1
+                    scheduled = True
+                    all_ranks_new_requests[target_dp_rank].append(req_item)
+
+            if not scheduled:
+                remaining_unscheduled.append(req_item)
+
+        # Balance the remaining unscheduled requests across ranks
+        num_new_requests_all_ranks = len(remaining_unscheduled)
+        total_num_active_requests = sum(all_ranks_num_active_requests)
+        self.expected_num_active_requests = max(
+            (total_num_active_requests + num_new_requests_all_ranks +
+             self.dist.tp_size - 1) // self.dist.tp_size,
+            max(all_ranks_num_active_requests),
+        )
+
+        all_ranks_new_requests = self._balance_requests_across_ranks(
+            remaining_unscheduled, all_ranks_new_requests,
+            all_ranks_num_active_requests)
+
+        return all_ranks_new_requests
+
     def _handle_request_broadcasting(self,
                                      new_requests: List[RequestQueueItem]):
         """Handle broadcasting of requests and Python objects across ranks."""
@@ -274,8 +372,13 @@ def _handle_request_broadcasting(self,
                 new_requests, "py_logits_post_processors")
             py_multimodal_data = self._collect_py_objects_from_requests(
                 new_requests, "py_multimodal_data")
+            py_scheduling_params = self._collect_py_objects_from_requests(
+                new_requests, "py_scheduling_params")
             py_request_objects = tuple(
-                filter(None, [py_logits_post_processors, py_multimodal_data]))
+                filter(None, [
+                    py_logits_post_processors, py_multimodal_data,
+                    py_scheduling_params
+                ]))
         else:
             py_request_objects = None

@@ -314,28 +417,30 @@ def _validate_and_filter_requests(

     def _balance_requests_across_ranks(
             self, new_requests: List[RequestQueueItem],
+            all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
             all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
         """Balance requests across ranks for attention DP."""
-        new_requests_cur_rank = []
-
-        if new_requests and self.expected_num_active_requests > all_ranks_num_active_requests[
-                self.dist.tp_rank]:
+        if new_requests:
             # Balance context tokens across ranks using heap
             HeapVal = namedtuple(
                 'HeapVal',
                 ['num_tokens', 'num_requests', 'rank', 'request_list'])

             all_ranks_new_requests_heap = [
-                HeapVal(0, self.expected_num_active_requests - val, tp_rank, [])
+                HeapVal(0, val, tp_rank, [])
                 for tp_rank, val in enumerate(all_ranks_num_active_requests)
             ]

-            new_requests_cur_rank = all_ranks_new_requests_heap[
-                self.dist.tp_rank].request_list
             all_ranks_new_requests_heap = [
                 val for val in all_ranks_new_requests_heap
-                if val.num_requests > 0
+                if val.num_requests < self.expected_num_active_requests
             ]
+
+            all_ranks_new_scheduled_requests = {
+                val.rank: val.request_list
+                for val in all_ranks_new_requests_heap
+            }
+
             heapq.heapify(all_ranks_new_requests_heap)

             # Sort by token count (descending) for better load balancing
@@ -351,17 +456,22 @@ def _balance_requests_across_ranks(
                 token_count = len(
                     getattr(req_item.request, 'input_token_ids',
                             [])) if req_item.request else 0
+                # Update the heap value with the new request
                 val = val._replace(
                     num_tokens=val.num_tokens + token_count,
-                    num_requests=val.num_requests - 1,
+                    num_requests=val.num_requests + 1,
                 )
+
                 val.request_list.append(req_item)
-                if val.num_requests > 0:
+                # If rank still has room for new requests, push back into heap
+                if val.num_requests < self.expected_num_active_requests:
                     heapq.heappush(all_ranks_new_requests_heap, val)
-                elif val.rank == self.dist.tp_rank:
-                    break

-        return new_requests_cur_rank
+            # Extend all_ranks_new_requests with the new requests that have been scheduled
+            for rank, reqs in all_ranks_new_scheduled_requests.items():
+                all_ranks_new_requests[rank].extend(reqs)
+
+        return all_ranks_new_requests

     def _collect_py_objects_from_requests(
             self, requests: List[RequestQueueItem],
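The admission rule added in _can_process_attention_dp_request can be exercised in isolation. The following is a minimal standalone sketch of that rule, not part of the diff itself: SchedulingParams and can_process are hypothetical stand-ins for the object carried in a request's py_scheduling_params and for the method, with the per-rank active-request counters and the max_num_active_requests limit passed in explicitly.

# Standalone sketch (illustrative only): mirrors the admission rule in the diff.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SchedulingParams:  # hypothetical stand-in for py_scheduling_params
    attention_dp_rank: Optional[int] = None
    attention_dp_relax: bool = False


def can_process(params: Optional[SchedulingParams],
                active_per_rank: List[int], max_active: int) -> bool:
    if params is None:
        return True  # no scheduling constraint: always admissible
    if params.attention_dp_rank is None or params.attention_dp_relax:
        return True  # untargeted or relaxed: admissible on any rank
    # Strict target: admit only while the target rank has capacity,
    # and reserve a slot so later strict requests see the updated count.
    if active_per_rank[params.attention_dp_rank] < max_active:
        active_per_rank[params.attention_dp_rank] += 1
        return True
    return False


# Two DP ranks, at most 2 active requests each; rank 0 is already full.
active = [2, 1]
assert not can_process(SchedulingParams(0, False), active, 2)  # rank 0 is full
assert can_process(SchedulingParams(0, True), active, 2)       # relax: placement deferred
assert can_process(SchedulingParams(1, False), active, 2)      # rank 1 has room
assert active == [2, 2]                                        # slot reserved on rank 1

Requests rejected by this rule are pushed back to the front of the waiting queue (via extendleft(reversed(...))) and retried on a later scheduling pass.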