async def display_kernel_pulling(compute_session: AsyncSession.ComputeSession) -> bool:
    """
    Render a percentage progress bar driven by the background task attached
    to *compute_session*.

    The helper listens to the events of ``compute_session.backgroundtask``
    and mirrors ``current_progress / total_progress`` onto a tqdm bar until
    the event stream ends.

    This is a closure: it reads ``session``, ``image`` and ``name`` from the
    enclosing scope.

    :param compute_session: The session object returned by
        ``get_or_create()``.  It only carries a ``backgroundtask`` attribute
        when the server response included a ``background_task`` id.
    :returns: ``False`` when the session has no background task to follow
        (the error is printed), ``True`` once the event stream has been
        consumed.
    """
    try:
        bgtask = compute_session.backgroundtask
    except AttributeError as e:
        # No background task was attached to this session object,
        # so there is nothing to display.
        print_error(e)
        return False

    with tqdm(total=100, unit='%') as pbar:
        async with bgtask.listen_events() as response:
            async for ev in response:
                if ev.event == 'bgtask_updated':
                    # Decode the payload only for the event that uses it, so
                    # that other events with an empty or non-JSON body do not
                    # abort the listener.
                    progress = json.loads(ev.data)
                    current = progress['current_progress']
                    total = progress['total_progress']
                    # A zero (or missing) total means "not started yet";
                    # also avoids a division by zero.
                    pbar.n = round(current / total * 100, 2) if total else 0
                    pbar.update(0)
                    pbar.refresh()
                elif ev.event == 'bgtask_done':
                    pbar.n = 100
                    pbar.update(0)
                    pbar.refresh()
                    pbar.clear()
                    # NOTE(review): the refreshed session returned here is not
                    # visible to the caller -- assigning it to the parameter
                    # name would only rebind this function's local.  The call
                    # is kept for its server-side effect; propagating the new
                    # object requires changing the return contract together
                    # with the call sites in _run().
                    await session.ComputeSession.get_or_create(
                        image,
                        name=name,
                    )
                    await asyncio.sleep(0.1)
                # NOTE(review): failure/cancellation events (if the server
                # emits any) are ignored and the function still returns
                # True -- confirm whether they should yield False.
    return True
.format(name)) diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index 8f633d67..d2371482 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -301,28 +301,6 @@ async def get_or_create( rqst.set_json(params) async with rqst.fetch() as resp: data = await resp.json() - if 'background_task' in data: - with tqdm(total=100, unit='%') as pbar: - task_id = data['background_task'] - bgtask = resp.session.BackgroundTask(task_id) - async with bgtask.listen_events() as response: - async for ev in response: - progress = json.loads(ev.data) - if ev.event == 'bgtask_updated': - current = progress['current_progress'] - total = progress['total_progress'] - if total == 0: - total = 1e-2 - pbar.n = round(current / total * 100, 2) - pbar.update(0) - pbar.refresh() - elif ev.event == 'bgtask_done': - pbar.n = 100.0 - pbar.update(0) - pbar.refresh() - pbar.clear() - async with rqst.fetch() as resp: - data = await resp.json() o = cls(name, owner_access_key) # type: ignore if api_session.get().api_version[0] >= 5: o.id = UUID(data['sessionId']) @@ -331,6 +309,9 @@ async def get_or_create( o.service_ports = data.get('servicePorts', []) o.domain = domain_name o.group = group_name + if 'background_task' in data: + task_id = data['background_task'] + o.backgroundtask = resp.session.BackgroundTask(task_id) return o @api_function