From 8f845a4e82b4da3b4af645a0c18424cff107b0ab Mon Sep 17 00:00:00 2001 From: plan Date: Mon, 24 Jun 2024 17:45:39 +0800 Subject: [PATCH] worker: Add guard for model launching Because model launching is a long process (download model, loading into GPU). client might encounter network error in the middle while worker is processing, add a guard the prevent duplicate operation with the same model_uid. Provide an rpc call get_model_launch_status() to return LuanchStatus, to determine whether worker is still working on this model_uid. --- xinference/core/worker.py | 119 ++++++++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 45 deletions(-) diff --git a/xinference/core/worker.py b/xinference/core/worker.py index 303550d197..dac5e3c288 100644 --- a/xinference/core/worker.py +++ b/xinference/core/worker.py @@ -73,6 +73,9 @@ def __init__( self._main_pool.recover_sub_pool = self.recover_sub_pool # internal states. + # temporary placeholder during model launch process: + self._model_uid_launching_guard: Dict[str, bool] = {} + # attributes maintained after model launched: self._model_uid_to_model: Dict[str, xo.ActorRefType["ModelActor"]] = {} self._model_uid_to_model_spec: Dict[str, ModelDescription] = {} self._gpu_to_model_uid: Dict[int, str] = {} @@ -594,10 +597,14 @@ async def launch_builtin_model( launch_args.pop("kwargs") launch_args.update(kwargs) - event_model_uid, _, __ = parse_replica_model_uid(model_uid) + try: + origin_uid, _, _ = parse_replica_model_uid(model_uid) + except Exception as e: + logger.exception(e) + raise try: await self._event_collector_ref.report_event( - event_model_uid, + origin_uid, Event( event_type=EventType.INFO, event_ts=int(time.time()), @@ -640,50 +647,55 @@ async def launch_builtin_model( assert model_uid not in self._model_uid_to_model self._check_model_is_valid(model_name, model_format) - subpool_address, devices = await self._create_subpool( - model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx - ) + if self.get_model_launch_status(model_uid) is not None: + raise ValueError(f"{model_uid} is running") try: - origin_uid, _, _ = parse_replica_model_uid(model_uid) - model, model_description = await asyncio.to_thread( - create_model_instance, - subpool_address, - devices, - model_uid, - model_type, - model_name, - model_engine, - model_format, - model_size_in_billions, - quantization, - peft_model_config, - **kwargs, - ) - await self.update_cache_status(model_name, model_description) - model_ref = await xo.create_actor( - ModelActor, - address=subpool_address, - uid=model_uid, - worker_address=self.address, - model=model, - model_description=model_description, - request_limits=request_limits, + self._model_uid_launching_guard[model_uid] = True + subpool_address, devices = await self._create_subpool( + model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx ) - await model_ref.load() - except: - logger.error(f"Failed to load model {model_uid}", exc_info=True) - self.release_devices(model_uid=model_uid) - await self._main_pool.remove_sub_pool(subpool_address) - raise - self._model_uid_to_model[model_uid] = model_ref - self._model_uid_to_model_spec[model_uid] = model_description - self._model_uid_to_addr[model_uid] = subpool_address - self._model_uid_to_recover_count.setdefault( - model_uid, MODEL_ACTOR_AUTO_RECOVER_LIMIT - ) - self._model_uid_to_launch_args[model_uid] = launch_args + try: + model, model_description = await asyncio.to_thread( + create_model_instance, + subpool_address, + devices, + model_uid, + model_type, + model_name, + model_engine, + model_format, + model_size_in_billions, + quantization, + peft_model_config, + **kwargs, + ) + await self.update_cache_status(model_name, model_description) + model_ref = await xo.create_actor( + ModelActor, + address=subpool_address, + uid=model_uid, + worker_address=self.address, + model=model, + model_description=model_description, + request_limits=request_limits, + ) + await model_ref.load() + except: + logger.error(f"Failed to load model {model_uid}", exc_info=True) + self.release_devices(model_uid=model_uid) + await self._main_pool.remove_sub_pool(subpool_address) + raise + self._model_uid_to_model[model_uid] = model_ref + self._model_uid_to_model_spec[model_uid] = model_description + self._model_uid_to_addr[model_uid] = subpool_address + self._model_uid_to_recover_count.setdefault( + model_uid, MODEL_ACTOR_AUTO_RECOVER_LIMIT + ) + self._model_uid_to_launch_args[model_uid] = launch_args + finally: + del self._model_uid_launching_guard[model_uid] # update status to READY abilities = await self._get_model_ability(model, model_type) @@ -694,10 +706,13 @@ async def launch_builtin_model( @log_async(logger=logger) async def terminate_model(self, model_uid: str): - event_model_uid, _, __ = parse_replica_model_uid(model_uid) + # Terminate model while its launching is not allow + if model_uid in self._model_uid_launching_guard: + raise ValueError(f"{model_uid} is launching") + origin_uid, _, __ = parse_replica_model_uid(model_uid) try: await self._event_collector_ref.report_event( - event_model_uid, + origin_uid, Event( event_type=EventType.INFO, event_ts=int(time.time()), @@ -708,7 +723,6 @@ async def terminate_model(self, model_uid: str): # Report callback error can be log and ignore, should not interrupt the Process logger.error("report_event error: %s" % (e)) - origin_uid, _, _ = parse_replica_model_uid(model_uid) await self._status_guard_ref.update_instance_info( origin_uid, {"status": LaunchStatus.TERMINATING.name} ) @@ -740,6 +754,21 @@ async def terminate_model(self, model_uid: str): origin_uid, {"status": LaunchStatus.TERMINATED.name} ) + # Provide an interface for future version of supervisor to call + def get_model_launch_status(self, model_uid: str) -> Optional[str]: + """ + returns: + CREATING: model is launching + RREADY: model is running + None: model is not running (launch error might have happened) + """ + + if model_uid in self._model_uid_launching_guard: + return LaunchStatus.CREATING.name + if model_uid in self._model_uid_to_model: + return LaunchStatus.READY.name + return None + @log_async(logger=logger) async def list_models(self) -> Dict[str, Dict[str, Any]]: ret = {}