diff --git a/saxml/server/model_service_base.py b/saxml/server/model_service_base.py
index d11eb27..c75bc62 100644
--- a/saxml/server/model_service_base.py
+++ b/saxml/server/model_service_base.py
@@ -93,6 +93,7 @@ def _cancelled(rpc):

   def _notify():
     for t in rpc_tasks:
+      t.release_device_resource()
       t.done(utils.cancelled())

   if all(_cancelled(t.rpc) for t in rpc_tasks):
@@ -446,9 +447,27 @@ def _finish_batch():
       with self._global_live_batches_lock:
         self._global_live_batches -= 1

+    next_rpc_tasks = []
     while True:
       batch_sem.acquire()
-      rpc_tasks = method.queue.take_batch(batch_size)
+      if not next_rpc_tasks:
+        next_rpc_tasks = method.queue.take_batch(batch_size)
+      if key.name == MethodName.BATCHED_LM_GENERATE:
+        # Make sure total slots in batch <= max available slots.
+        assert model is not None
+        cache_slots = model.method(key.model_method).num_cache_slots
+        rpc_tasks = []
+        total_slots = 0
+        for t in next_rpc_tasks:
+          if total_slots + t.aux['slot_count'] > cache_slots:
+            break
+          rpc_tasks.append(t)
+          total_slots += t.aux['slot_count']
+        next_rpc_tasks = next_rpc_tasks[len(rpc_tasks):]
+        # Larger slot_count comes first.
+        rpc_tasks = sorted(rpc_tasks, key=lambda t: -t.aux['slot_count'])
+      else:
+        rpc_tasks, next_rpc_tasks = next_rpc_tasks, []
       if method.admissioner.is_shutdown():
         # This must be an empty task generated after shutdown to unblock this
         # thread.
@@ -498,6 +517,7 @@ def _finish_batch():
           error_msg = f'Preprocessing error: {e}\n{traceback.format_exc()}'
           logging.error(error_msg)
           for rpc_task in rpc_tasks:
+            rpc_task.release_device_resource()
             rpc_task.done(utils.internal_error(error_msg))
           if not presync:
             _finish_batch()
@@ -530,7 +550,7 @@ def unregister_method(self, key: MethodKey):
     method = self._per_method_queues[key]
     method.admissioner.shutdown()
     # An empty task to unblock the batcher thread.
-    method.queue.send(None, None, None, None)
+    method.queue.send(None, None, None, None, None)
     del self._per_method_queues[key]

   def has_method(self, key: MethodKey) -> bool:
@@ -590,6 +610,8 @@ def done(status, *args, **kwargs):
     # Check ACLs.
     model: servable_model.ServableModel | None = method.model
+    acquire_count = 1
+    aux = None
     if model is not None and key.model_method is not None:
       model_method_name: str = key.model_method
@@ -613,8 +635,25 @@ def done(status, *args, **kwargs):
       )
       if not validate_status.ok():
         return done(validate_status)
-
-    success, active = method.admissioner.acquire(blocking=False)
+      if key.name == MethodName.BATCHED_LM_GENERATE:
+        acquire_count = servable_method.num_cache_slots_for_request(req)
+        if acquire_count > servable_method.num_cache_slots:
+          return done(
+              utils.invalid_arg(
+                  'More slots requested than max:'
+                  f' {acquire_count} {servable_method.num_cache_slots}'
+              )
+          )
+        if acquire_count > 1 and servable_method.streamable_output:
+          return done(
+              utils.unimplemented(
+                  'Dynamic sample count not supported in streamable methods'
+              )
+          )
+        aux = {'slot_count': acquire_count, 'finished_results': []}
+    success, active = method.admissioner.acquire(
+        blocking=False, count=acquire_count
+    )

     if not active:
       return done(utils.not_found(f'method {key} is unloaded'))
@@ -623,11 +662,12 @@ def done(status, *args, **kwargs):
           utils.resource_exhausted(f'Too many requests: {key} {method.limit()}')
       )

-    def _done(status: utils.Status, *args, **kwargs):
-      done(status, *args, **kwargs)
-      method.admissioner.release()
+    def _release_device_resource():
+      method.admissioner.release(count=1)

-    method.queue.send(rpc, req, resp, _done, trace_callback)
+    method.queue.send(
+        rpc, req, resp, _release_device_resource, done, trace_callback, aux
+    )

   def get_batch(self) -> Batch:
     """Dequeues an available batch."""
@@ -1839,6 +1879,11 @@ def _postprocess():
           utils.traceprint_all(
               batch.rpc_tasks, f'After output_to_host: {batch.method}'
           )
+          for task in batch.rpc_tasks:
+            task.release_device_resource()
+        except Exception as e:  # pylint: disable=broad-except
+          logging.fatal('Error during waiting for device result %s', e)
+        try:
           if not pre_process_failure:
             # No more result for streaming.
             if streaming_done is not None:
@@ -1988,6 +2033,7 @@ def _run_primary_worker_loop(self):
                   dict(request.overrides),
                   prng_seed,
               )
+              task.release_device_resource()
               task.done(utils.ok())
             except ValueError as e:
               self._log_exception(
@@ -1999,6 +2045,7 @@ def _run_primary_worker_loop(self):
                   request.model_path,
                   e,
               )
+              task.release_device_resource()
               task.done(utils.invalid_arg(f'{e}'))
             except Exception as e:  # pylint: disable=broad-except
               self._log_exception(
@@ -2010,6 +2057,7 @@ def _run_primary_worker_loop(self):
                   request.model_path,
                   e,
               )
+              task.release_device_resource()
               task.done(utils.internal_error(f'Loading error: {e}'))
         case MethodName.UNLOAD:
           with batch:
@@ -2023,6 +2071,7 @@ def _run_primary_worker_loop(self):
                 raise ValueError('model_key is not specified.')
               self._inform_secondary_hosts(batch.method.name, model_key)
               self._loaded_models.unload(model_key)
+              task.release_device_resource()
               task.done(utils.ok())
             except ValueError as e:
               logging.exception(
@@ -2030,6 +2079,7 @@ def _run_primary_worker_loop(self):
                   model_key,
                   e,
               )
+              task.release_device_resource()
               task.done(utils.invalid_arg(f'Unloading error: {e}'))
             except Exception as e:  # pylint: disable=broad-except
               self._log_exception(
@@ -2037,6 +2087,7 @@ def _run_primary_worker_loop(self):
                   model_key,
                   e,
               )
+              task.release_device_resource()
               task.done(utils.internal_error(f'Unloading error: {e}'))
         case MethodName.EXPORT:
           with batch:
@@ -2057,6 +2108,7 @@ def _run_primary_worker_loop(self):
                 self._worker_thread_exception = e
                 break
               exporter.finalize_export(*export_args)
+              task.release_device_resource()
               task.done(utils.ok())
             except Exception as e:  # pylint: disable=broad-except
               self._log_exception(
@@ -2070,6 +2122,7 @@ def _run_primary_worker_loop(self):
                   request.export_path,
                   e,
               )
+              task.release_device_resource()
               task.done(exporter.export_error_to_status(e))
         case MethodName.SAVE:
           with batch:
@@ -2081,6 +2134,7 @@ def _run_primary_worker_loop(self):
                   batch.method.name, request.model_key, request.checkpoint_path
               )
               self._save_model(request.model_key, request.checkpoint_path)
+              task.release_device_resource()
               task.done(utils.ok())
             except ValueError as e:
               self._log_exception(
@@ -2088,6 +2142,7 @@ def _run_primary_worker_loop(self):
                   request.model_key,
                   e,
               )
+              task.release_device_resource()
               task.done(utils.invalid_arg(f'Save checkpoint error: {e}'))
             except Exception as e:  # pylint: disable=broad-except
               self._log_exception(
@@ -2098,6 +2153,7 @@ def _run_primary_worker_loop(self):
                   request.model_key,
                   e,
               )
+              task.release_device_resource()
               task.done(utils.internal_error(f'Saving checkpoint error: {e}'))
         case MethodName.TERMINATE:
           with batch:
@@ -2212,32 +2268,39 @@ def _run_prefill_insert_loop(
     while True:
       # Block if there is no prefill requests.
       request = state.prefill_queue.get()
-      n_tasks = len(request.rpc_tasks)
-      # Block if there is no available cache slot.
-      slots = []
-      for i in range(n_tasks):
-        slots.append(state.available_slots.get())
-        # Wake up the generation loop when no available slot is left
-        if state.available_slots.empty() and i < n_tasks - 1:
-          state.pending_insert = False
-          with state.generate_cv:
-            state.generate_cv.notify()
-      slots = tuple(slots)
+      # Tasks are sorted by slot_count.
+      max_slots_per_task = request.rpc_tasks[0].aux['slot_count']
+      slots_to_assign = sum(t.aux['slot_count'] for t in request.rpc_tasks)
+      assert slots_to_assign <= state.num_cache_slots
+      per_sample_slots = [[] for _ in range(max_slots_per_task)]
+      for t in request.rpc_tasks:
+        n_samples = t.aux['slot_count']
+        assert n_samples <= max_slots_per_task
+        for i in range(n_samples):
+          # Block if there is no available cache slot.
+          per_sample_slots[i].append(state.available_slots.get())
+          slots_to_assign -= 1
+          # Wake up the generation loop when no available slot is left
+          if state.available_slots.empty() and slots_to_assign > 0:
+            state.pending_insert = False
+            with state.generate_cv:
+              state.generate_cv.notify()
       prefill_dequeue_time = time.time()

       # Take unused slots.
-      logging.info('Taking slots %s', slots)
+      logging.info('Taking slots %s', per_sample_slots)

       # Atomic mutation. Safe to be outside of generate_cv guard.
       state.pending_insert = True
+      per_sample_results = []
       with state.generate_cv:
         with self._device_compute_mutex:
           self._inform_secondary_hosts(
               MethodName.PREFILL_INSERT,
               state.model_key,
               state.model_method,
-              ','.join([str(s) for s in slots]),
+              json.dumps(per_sample_slots),
               skip_host_sync=False,
           )
@@ -2248,25 +2311,43 @@
               inputs=request.preprocessed_inputs,
           )
           insert_start_time = time.time()
-          method_obj.insert(prefix_cache, slots)
+          for i, slots in enumerate(per_sample_slots):
+            if i > 0:
+              scores, tokens, prefix_cache = method_obj.resample_initial_tokens(
+                  prefix_cache
+              )
+            method_obj.insert(prefix_cache, tuple(slots))
+            per_sample_results.append((scores, tokens))
           del prefix_cache

-        host_tokens = np.array(tokens.addressable_data(0))
-        host_scores = np.array(scores.addressable_data(0))
-        state.decoded_tokens[slots, 0] = host_tokens[:n_tasks]
-        state.scores[slots, ...] = host_scores[:n_tasks]
+        host_tokens = []
+        host_scores = []
+        all_slots = []
+        expanded_tasks = []
+        for (scores, tokens), slots in zip(
+            per_sample_results, per_sample_slots
+        ):
+          expanded_tasks.extend(request.rpc_tasks[:len(slots)])
+          host_tokens.append(np.array(tokens.addressable_data(0))[:len(slots)])
+          host_scores.append(np.array(scores.addressable_data(0))[:len(slots)])
+          all_slots.extend(slots)
+        host_tokens = np.concatenate(host_tokens, axis=0)
+        host_scores = np.concatenate(host_scores, axis=0)
+        all_slots = tuple(all_slots)
+        state.decoded_tokens[all_slots, 0] = host_tokens
+        state.scores[all_slots, ...] = host_scores
         self._run_post_prefill_async(
             state,
-            slots,
+            all_slots,
             copy.deepcopy(host_tokens),
             copy.deepcopy(host_scores),
-            request.rpc_tasks,
+            expanded_tasks,
         )
-        state.steps[slots, ...] = 1
+        state.steps[all_slots, ...] = 1
         # Must set slots_in_use in the end.
-        state.slots_in_use[slots, ...] = 1
+        state.slots_in_use[all_slots, ...] = 1

         # Update stats
         state.update_stats(
@@ -2307,6 +2388,7 @@ def _postprocess():
       output_strings = state.method.detokenize(tokens)

       for i, slot in enumerate(slots):
+        assert state.rpc_tasks[slot].aux['slot_count'] == 1
         resp = copy.deepcopy(state.rpc_tasks[slot].response)
         outputs = output_strings[i], scores[i]
         self._model_services[state.service_id].FillRPCResponse(
@@ -2392,7 +2474,7 @@ def _run_generation_loop(
         # Release the slots.
         for slot in np.flatnonzero(done):
           logging.info('Releasing slot %d.', slot)
-          state.available_slots.put(slot)
+          state.available_slots.put(int(slot))

         prev_result = current_result
         prev_mask = current_mask
@@ -2409,26 +2491,37 @@ def _run_post_generate_async(
     def _postprocess():
       # If any of the sequences in the batch is done, return the response
      # and reset the cache slot.
-      if state.method.service_id() == 'custom':
-        method_outputs = state.method.post_processing(
-            {'tokens': sequences, 'scores': scores}
-        )
-      else:
-        method_outputs = state.method.detokenize(sequences)
       for idx, slot in enumerate(np.flatnonzero(done)):
+        rpc_task = state.rpc_tasks[slot]
+        assert isinstance(rpc_task, utils.RpcQueueTask)
+        rpc_task.release_device_resource()
+        rpc_task.aux['finished_results'].append((sequences[idx], scores[idx]))
+        if rpc_task.aux['slot_count'] > len(rpc_task.aux['finished_results']):
+          assert not state.method.streamable_output
+          continue
+        # [num_samples, ...]
+        seqs = np.stack(
+            [x for x, _ in rpc_task.aux['finished_results']], axis=0
+        )
+        scrs = np.stack(
+            [x for _, x in rpc_task.aux['finished_results']], axis=0
+        )
         if state.method.service_id() == 'custom':
-          outputs = method_outputs[idx]
+          outputs = state.method.post_processing({
+              'tokens': np.expand_dims(seqs, 0),  # [batch1, num_samples, ...]
+              'scores': np.expand_dims(scrs, 0),
+          })[0]
         else:
-          outputs = [method_outputs[idx]], [scores[idx]]
+          outputs = state.method.detokenize(seqs), scrs
         if state.method.streamable_output:
           # send response back to generate_stream
-          resp = copy.deepcopy(state.rpc_tasks[slot].response)
+          resp = copy.deepcopy(rpc_task.response)
           self._model_services[state.service_id].FillRPCResponse(
               state.model_method, outputs, resp
           )
           try:
-            state.rpc_tasks[slot].done(utils.ok(), resp=resp)
+            rpc_task.done(utils.ok(), resp=resp)
           except Exception as e:  # pylint: disable=broad-except
             self._log_exception(
                 'Error occurred: %s, error: %s', state.model_key, e
@@ -2436,7 +2529,7 @@ def _postprocess():
           # send response done back to generate_stream
           try:
-            state.rpc_tasks[slot].done(utils.ok())
+            rpc_task.done(utils.ok())
           except Exception as e:  # pylint: disable=broad-except
             self._log_exception(
                 'Error occurred: %s, error: %s', state.model_key, e
@@ -2445,10 +2538,10 @@ def _postprocess():
         else:
           # send response back to generate
           self._model_services[state.service_id].FillRPCResponse(
-              state.model_method, outputs, state.rpc_tasks[slot].response
+              state.model_method, outputs, rpc_task.response
          )
           try:
-            state.rpc_tasks[slot].done(utils.ok())
+            rpc_task.done(utils.ok())
           except Exception as e:  # pylint: disable=broad-except
             self._log_exception(
                 'Error occurred: %s, error: %s', state.model_key, e
@@ -2560,13 +2653,16 @@ def _run_secondary_worker_loop(self):
             break
         case MethodName.PREFILL_INSERT:
           try:
-            model_key, model_method, slots = msgs
-            slots = [int(s) for s in slots.split(',')]
+            model_key, model_method, per_sample_slots = msgs
+            per_sample_slots = json.loads(per_sample_slots)
             method_obj = self._loaded_models.get_model(model_key).method(
                 model_method
             )
             _, _, state = method_obj.prefill_with_dummy()
-            method_obj.insert(state, slots)
+            for i, slots in enumerate(per_sample_slots):
+              if i > 0:
+                _, _, state = method_obj.resample_initial_tokens(state)
+              method_obj.insert(state, tuple(slots))
             del state
           except Exception as e:  # pylint: disable=broad-except
             self._worker_thread_exception = e
diff --git a/saxml/server/servable_model.py b/saxml/server/servable_model.py
index efb6239..d2a793b 100644
--- a/saxml/server/servable_model.py
+++ b/saxml/server/servable_model.py
@@ -313,6 +313,10 @@ def continuous_batching(self) -> bool:
   def num_cache_slots(self) -> int:
     raise NotImplementedError('num_cache_slots not implemented')

+  def num_cache_slots_for_request(self, req: Any) -> int:
+    del req
+    return 1
+
   @property
   def max_decode_steps(self) -> int:
     raise NotImplementedError('max_decode_steps not implemented')
@@ -345,6 +349,21 @@ def prefill_with_dummy(
     """
     raise NotImplementedError('prefill_with_dummy not implemented')

+  def resample_initial_tokens(
+      self, cache: DeviceTensors
+  ) -> tuple[DeviceTensors, DeviceTensors, DeviceTensors]:
+    """Resamples the initial tokens from prefill with different randomness.
+
+    Args:
+      cache: Original prefilled KV state.
+
+    Returns:
+      scores: Log probability [B] of sampled next tokens.
+      token: Next token [B] of the prompt, sampled by model's default sampler.
+      cache: Updated KV state.
+    """
+    raise NotImplementedError('resample_initial_tokens not implemented')
+
   def insert(
       self, prefix_state: DeviceTensors, slot: int | Sequence[int]
   ) -> None:
diff --git a/saxml/server/utils.py b/saxml/server/utils.py
index d2d7aad..2b8736b 100644
--- a/saxml/server/utils.py
+++ b/saxml/server/utils.py
@@ -96,8 +96,11 @@ class RpcQueueTask:
   rpc: Optional[RPCContext]
   request: Optional[message.Message]
   response: Optional[message.Message]
+  release_device_resource: Optional[Callable[[], None]]
   done: Optional[StatusCallback]
   tc: Optional[TracerPrintCallback]
+  # Any additional data.
+  aux: dict[str, Any] = dataclasses.field(default_factory=dict)


 def traceprint_all(rpc_tasks: Sequence[RpcQueueTask], msg: str):
@@ -119,8 +122,10 @@ def send(
       rpc: Optional[RPCContext],
       request: Optional[message.Message],
       response: Optional[message.Message],
+      release_device_resource: Optional[Callable[[], None]],
       done: Optional[StatusCallback],
       tc: Optional[TracerPrintCallback] = None,
+      aux: Optional[dict[str, Any]] = None,
   ):
     """Called from RPC handler to schedule a task for processing.

@@ -128,10 +133,23 @@
       rpc: the rpc object.
       request: request protocol message
       response: response protocol message
+      release_device_resource: callback for releasing device resource to control
+        the maximum live batches.
       done: A callback when the rpc handling is done.
       tc: optional TracerPrintCallback object.
+      aux: any additional data.
     """
-    self._queue.put(RpcQueueTask(rpc, request, response, done, tc))
+    self._queue.put(
+        RpcQueueTask(
+            rpc,
+            request,
+            response,
+            release_device_resource,
+            done,
+            tc,
+            aux=aux,
+        )
+    )

   def take_batch(self, batch_size: int) -> List[RpcQueueTask]:
     """Returns up to batch_size RpcQueueTask objects from the queue.
@@ -239,29 +257,30 @@ def __init__(self, limit):
     self._active = True
     self._shutdown = False

-  def acquire(self, blocking: bool = True) -> Tuple[bool, bool]:
+  def acquire(self, blocking: bool = True, count: int = 1) -> Tuple[bool, bool]:
     """Acquires resource.

     Args:
       blocking: whether the invocation is blocking.
+      count: number of resources to acquire.

     Returns:
       A tuple of 2 bools. The first indicates if it's successful, and the
       second indicates if the resource is still active.
     """
     with self._cv:
-      while self._count >= self._limit:
+      while self._count + count > self._limit:
         if not blocking:
           return False, self._active
         self._cv.wait()
       if not self._active:
         return False, False
-      self._count += 1
+      self._count += count
       return True, True

-  def release(self):
+  def release(self, count: int = 1):
     with self._cv:
-      self._count -= 1
+      self._count -= count
       self._cv.notify_all()

   def is_shutdown(self):
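
The core resource change is in utils.py: the admissioner becomes a counted limiter, so a BATCHED_LM_GENERATE request acquires one unit per cache slot it will occupy and gives the units back one at a time as its slots finish (release(count=1) via release_device_resource). A self-contained sketch of those acquire/release semantics follows; the Admissioner name and the limit value here are illustrative stand-ins, not imports from the patch.

import threading


class Admissioner:
  """Counted, shutdown-aware resource limiter (sketch of the utils.py semantics)."""

  def __init__(self, limit: int):
    self._limit = limit
    self._count = 0
    self._active = True
    self._cv = threading.Condition()

  def acquire(self, blocking: bool = True, count: int = 1) -> tuple[bool, bool]:
    with self._cv:
      # Admit only if all `count` units fit under the limit at once.
      while self._count + count > self._limit:
        if not blocking:
          return False, self._active
        self._cv.wait()
      if not self._active:
        return False, False
      self._count += count
      return True, True

  def release(self, count: int = 1):
    with self._cv:
      self._count -= count
      self._cv.notify_all()


adm = Admissioner(limit=4)
ok, active = adm.acquire(blocking=False, count=3)  # one request, three slots
assert ok and active
ok, _ = adm.acquire(blocking=False, count=2)       # 3 + 2 > 4: rejected
assert not ok
adm.release(count=1)                               # one slot finishes early
ok, _ = adm.acquire(blocking=False, count=2)       # 2 + 2 <= 4: admitted
assert ok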
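
On the admission path, num_cache_slots_for_request decides how many units a request needs; the base implementation returns 1, so existing methods are unaffected. A method that decodes several samples per request would override it, roughly as below; the subclass and the num_samples field are assumptions for illustration, not part of this patch.

from saxml.server import servable_model


class MultiSampleGenerate(servable_model.ServableMethod):  # hypothetical subclass
  """Reserves one KV-cache slot per requested sample."""

  def num_cache_slots_for_request(self, req) -> int:
    # Assumption: the request proto exposes a per-request sample count; the
    # real field name depends on the service definition.
    return max(1, getattr(req, 'num_samples', 1))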
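
In the batcher loop, BATCHED_LM_GENERATE batches are now packed by slot budget rather than by task count: tasks are taken in FIFO order until the next one would push the slot total past num_cache_slots, the remainder is carried over in next_rpc_tasks for the following iteration, and the packed batch is sorted largest slot_count first (the prefill loop relies on that ordering when it builds per_sample_slots). A standalone sketch of the packing step, where Task is a stand-in for utils.RpcQueueTask:

import dataclasses


@dataclasses.dataclass
class Task:
  """Stand-in for utils.RpcQueueTask; only the aux dict matters here."""
  aux: dict


def pack_by_slots(pending, cache_slots):
  """Greedily packs tasks whose total slot_count fits into cache_slots."""
  batch, total = [], 0
  for t in pending:
    if total + t.aux['slot_count'] > cache_slots:
      break  # the first task that does not fit ends the batch (FIFO preserved)
    batch.append(t)
    total += t.aux['slot_count']
  leftover = pending[len(batch):]
  # Larger slot_count first, matching the sort in the batcher loop.
  return sorted(batch, key=lambda t: -t.aux['slot_count']), leftover


pending = [Task({'slot_count': 2}), Task({'slot_count': 1}), Task({'slot_count': 3})]
batch, leftover = pack_by_slots(pending, cache_slots=4)
assert [t.aux['slot_count'] for t in batch] == [2, 1]
assert [t.aux['slot_count'] for t in leftover] == [3]

Stopping at the first task that does not fit keeps arrival order intact instead of skipping ahead to smaller requests, trading some slot utilization for fairness.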
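
The primary/secondary wire format for PREFILL_INSERT changes from a flat comma-joined slot string to JSON because the assignment is now ragged: per_sample_slots[i] lists the slot for each task's i-th sample, and since tasks are sorted largest-first, each row is a prefix of the task list. For example, task A with three samples in slots 0, 1, 2 and task B with one sample in slot 5:

import json

# Row 0: first samples of A and B; rows 1-2: extra samples of A only.
per_sample_slots = [[0, 5], [1], [2]]
wire = json.dumps(per_sample_slots)
assert wire == '[[0, 5], [1], [2]]'
assert json.loads(wire) == per_sample_slots

The secondary replays the same loop shape: row 0 is inserted from its dummy prefill, and every later row calls resample_initial_tokens first, keeping the device programs on both hosts in lockstep.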
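
resample_initial_tokens is the one new device-side primitive: it takes an already-prefilled cache and redraws the first sampled token with fresh randomness, so each extra sample of a request starts from an independent draw without re-running prefill. Its real implementation is backend-specific; below is a minimal JAX sketch that assumes the prefix state is a dict keeping the last-step logits under 'logits' and a PRNG key under 'rng' (both assumed names, not from the patch).

import jax
import jax.numpy as jnp


def resample_initial_tokens(cache):
  """Sketch: redraw the next token from the cached prefill logits."""
  rng, sample_rng = jax.random.split(cache['rng'])
  logits = cache['logits']  # [B, vocab], from the last prefill step (assumed)
  token = jax.random.categorical(sample_rng, logits)  # [B]
  scores = jnp.take_along_axis(
      jax.nn.log_softmax(logits), token[:, None], axis=-1
  )[:, 0]  # [B] log-probability of each sampled token
  return scores, token, dict(cache, rng=rng)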