ENH: Add guard for model launching #1680

Merged 1 commit on Jul 4, 2024
119 changes: 74 additions & 45 deletions xinference/core/worker.py
@@ -73,6 +73,9 @@ def __init__(
self._main_pool.recover_sub_pool = self.recover_sub_pool

# internal states.
# temporary placeholder during model launch process:
self._model_uid_launching_guard: Dict[str, bool] = {}
# attributes maintained after model launched:
self._model_uid_to_model: Dict[str, xo.ActorRefType["ModelActor"]] = {}
self._model_uid_to_model_spec: Dict[str, ModelDescription] = {}
self._gpu_to_model_uid: Dict[int, str] = {}
@@ -594,10 +597,14 @@ async def launch_builtin_model(
launch_args.pop("kwargs")
launch_args.update(kwargs)

event_model_uid, _, __ = parse_replica_model_uid(model_uid)
try:
origin_uid, _, _ = parse_replica_model_uid(model_uid)
except Exception as e:
logger.exception(e)
raise
try:
await self._event_collector_ref.report_event(
event_model_uid,
origin_uid,
Event(
event_type=EventType.INFO,
event_ts=int(time.time()),
@@ -640,50 +647,55 @@ async def launch_builtin_model(
assert model_uid not in self._model_uid_to_model
self._check_model_is_valid(model_name, model_format)

subpool_address, devices = await self._create_subpool(
model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx
)
if self.get_model_launch_status(model_uid) is not None:
raise ValueError(f"{model_uid} is running")

try:
origin_uid, _, _ = parse_replica_model_uid(model_uid)
model, model_description = await asyncio.to_thread(
create_model_instance,
subpool_address,
devices,
model_uid,
model_type,
model_name,
model_engine,
model_format,
model_size_in_billions,
quantization,
peft_model_config,
**kwargs,
)
await self.update_cache_status(model_name, model_description)
model_ref = await xo.create_actor(
ModelActor,
address=subpool_address,
uid=model_uid,
worker_address=self.address,
model=model,
model_description=model_description,
request_limits=request_limits,
self._model_uid_launching_guard[model_uid] = True
subpool_address, devices = await self._create_subpool(
model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx
)
await model_ref.load()
except:
logger.error(f"Failed to load model {model_uid}", exc_info=True)
self.release_devices(model_uid=model_uid)
await self._main_pool.remove_sub_pool(subpool_address)
raise

self._model_uid_to_model[model_uid] = model_ref
self._model_uid_to_model_spec[model_uid] = model_description
self._model_uid_to_addr[model_uid] = subpool_address
self._model_uid_to_recover_count.setdefault(
model_uid, MODEL_ACTOR_AUTO_RECOVER_LIMIT
)
self._model_uid_to_launch_args[model_uid] = launch_args
try:
model, model_description = await asyncio.to_thread(
create_model_instance,
subpool_address,
devices,
model_uid,
model_type,
model_name,
model_engine,
model_format,
model_size_in_billions,
quantization,
peft_model_config,
**kwargs,
)
await self.update_cache_status(model_name, model_description)
model_ref = await xo.create_actor(
ModelActor,
address=subpool_address,
uid=model_uid,
worker_address=self.address,
model=model,
model_description=model_description,
request_limits=request_limits,
)
await model_ref.load()
except:
logger.error(f"Failed to load model {model_uid}", exc_info=True)
self.release_devices(model_uid=model_uid)
await self._main_pool.remove_sub_pool(subpool_address)
raise
self._model_uid_to_model[model_uid] = model_ref
self._model_uid_to_model_spec[model_uid] = model_description
self._model_uid_to_addr[model_uid] = subpool_address
self._model_uid_to_recover_count.setdefault(
model_uid, MODEL_ACTOR_AUTO_RECOVER_LIMIT
)
self._model_uid_to_launch_args[model_uid] = launch_args
finally:
del self._model_uid_launching_guard[model_uid]

# update status to READY
abilities = await self._get_model_ability(model, model_type)
@@ -694,10 +706,13 @@ async def launch_builtin_model(

@log_async(logger=logger)
async def terminate_model(self, model_uid: str):
event_model_uid, _, __ = parse_replica_model_uid(model_uid)
# Terminating a model while it is launching is not allowed
if model_uid in self._model_uid_launching_guard:
raise ValueError(f"{model_uid} is launching")
origin_uid, _, __ = parse_replica_model_uid(model_uid)
try:
await self._event_collector_ref.report_event(
event_model_uid,
origin_uid,
Event(
event_type=EventType.INFO,
event_ts=int(time.time()),
@@ -708,7 +723,6 @@ async def terminate_model(self, model_uid: str):
# Report callback errors can be logged and ignored; they should not interrupt the process
logger.error("report_event error: %s" % (e))

origin_uid, _, _ = parse_replica_model_uid(model_uid)
await self._status_guard_ref.update_instance_info(
origin_uid, {"status": LaunchStatus.TERMINATING.name}
)
@@ -740,6 +754,21 @@ async def terminate_model(self, model_uid: str):
origin_uid, {"status": LaunchStatus.TERMINATED.name}
)

# Provide an interface for future versions of the supervisor to call
def get_model_launch_status(self, model_uid: str) -> Optional[str]:
"""
returns:
CREATING: model is launching
READY: model is running
None: model is not running (launch error might have happened)
"""

if model_uid in self._model_uid_launching_guard:
return LaunchStatus.CREATING.name
if model_uid in self._model_uid_to_model:
return LaunchStatus.READY.name
return None

@log_async(logger=logger)
async def list_models(self) -> Dict[str, Dict[str, Any]]:
ret = {}
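The patch above reserves a per-model entry in _model_uid_launching_guard for the whole launch, releases it in a finally block, and uses it both to reject a second launch of the same uid and to refuse terminate_model while the launch is still in flight. Below is a minimal, runnable sketch of that guard pattern; the Worker class, its launch/terminate/_load_model methods, and the demo in main are simplified stand-ins for illustration, not the actual xinference worker API.

import asyncio
import enum
from typing import Any, Dict, Optional


class LaunchStatus(enum.Enum):
    # Mirrors the CREATING / READY states referenced in the diff.
    CREATING = 1
    READY = 2


class Worker:
    def __init__(self) -> None:
        # Temporary placeholder held only while a model is launching.
        self._model_uid_launching_guard: Dict[str, bool] = {}
        # Populated only after a model has launched successfully.
        self._model_uid_to_model: Dict[str, Any] = {}

    def get_model_launch_status(self, model_uid: str) -> Optional[str]:
        if model_uid in self._model_uid_launching_guard:
            return LaunchStatus.CREATING.name
        if model_uid in self._model_uid_to_model:
            return LaunchStatus.READY.name
        return None  # not running; a launch error may have happened

    async def launch(self, model_uid: str) -> None:
        # Reject launching a uid that is already creating or ready.
        if self.get_model_launch_status(model_uid) is not None:
            raise ValueError(f"{model_uid} is running")
        self._model_uid_launching_guard[model_uid] = True
        try:
            model = await self._load_model(model_uid)  # may raise
            self._model_uid_to_model[model_uid] = model
        finally:
            # The guard is always released, whether the launch succeeded or failed.
            del self._model_uid_launching_guard[model_uid]

    async def terminate(self, model_uid: str) -> None:
        # Terminating a model while it is still launching is not allowed.
        if model_uid in self._model_uid_launching_guard:
            raise ValueError(f"{model_uid} is launching")
        self._model_uid_to_model.pop(model_uid, None)

    async def _load_model(self, model_uid: str) -> Any:
        await asyncio.sleep(0.1)  # stand-in for the real model loading work
        return object()


async def main() -> None:
    worker = Worker()
    task = asyncio.create_task(worker.launch("my-model"))
    await asyncio.sleep(0)  # let the launch coroutine start
    print(worker.get_model_launch_status("my-model"))  # CREATING
    await task
    print(worker.get_model_launch_status("my-model"))  # READY


asyncio.run(main())

The same try/finally shape is what keeps the real worker's guard dict from leaking entries when create_model_instance or model_ref.load() raises: the guard is dropped on every exit path, while the post-launch bookkeeping dicts are only filled in on success.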