Skip to content
This repository has been archived by the owner on Sep 22, 2023. It is now read-only.

Commit

Permalink
fix:separate the console output handling and the abstract bgtask hand…
Browse files Browse the repository at this point in the history
…ling.

fix: over tab
  • Loading branch information
youngjun0627 committed Sep 12, 2021
1 parent 2694168 commit 171918b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
42 changes: 40 additions & 2 deletions src/ai/backend/client/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Sequence,
Tuple,
)
from tqdm import tqdm

import aiohttp
import click
Expand Down Expand Up @@ -599,10 +600,45 @@ async def _run(session, idx, name, envs,
except Exception as e:
print_fail('[{0}] {1}'.format(idx, e))
return

async def display_kernel_pulling(compute_session: AsyncSession.ComputeSession) -> bool:
try:
bgtask = compute_session.backgroundtask
except Exception as e:
print_error(e)
return False
else:
with tqdm(total=100, unit='%') as pbar:
async with bgtask.listen_events() as response:
async for ev in response:
progress = json.loads(ev.data)
if ev.event == 'bgtask_updated':
current = progress['current_progress']
total = progress['total_progress']
if total == 0:
pbar.n = 0
else:
pbar.n = round(current / total * 100, 2)
pbar.update(0)
pbar.refresh()
elif ev.event == 'bgtask_done':
pbar.n = 100
pbar.update(0)
pbar.refresh()
pbar.clear()
compute_session = await session.ComputeSession.get_or_create(
image,
name=name,
)
await asyncio.sleep(0.1)
return True

if compute_session.status == 'PENDING':
print_info('Session ID {0} is enqueued for scheduling.'
.format(name))
return
result = await display_kernel_pulling(compute_session)
if not result:
return
elif compute_session.status == 'SCHEDULED':
print_info('Session ID {0} is scheduled and about to be started.'
.format(name))
Expand All @@ -623,7 +659,9 @@ async def _run(session, idx, name, envs,
elif compute_session.status == 'TIMEOUT':
print_info('Session ID {0} is still on the job queue.'
.format(name))
return
result = await display_kernel_pulling(compute_session)
if not result:
return
elif compute_session.status in ('ERROR', 'CANCELLED'):
print_fail('Session ID {0} has an error during scheduling/startup or cancelled.'
.format(name))
Expand Down
25 changes: 3 additions & 22 deletions src/ai/backend/client/func/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,28 +301,6 @@ async def get_or_create(
rqst.set_json(params)
async with rqst.fetch() as resp:
data = await resp.json()
if 'background_task' in data:
with tqdm(total=100, unit='%') as pbar:
task_id = data['background_task']
bgtask = resp.session.BackgroundTask(task_id)
async with bgtask.listen_events() as response:
async for ev in response:
progress = json.loads(ev.data)
if ev.event == 'bgtask_updated':
current = progress['current_progress']
total = progress['total_progress']
if total == 0:
total = 1e-2
pbar.n = round(current / total * 100, 2)
pbar.update(0)
pbar.refresh()
elif ev.event == 'bgtask_done':
pbar.n = 100.0
pbar.update(0)
pbar.refresh()
pbar.clear()
async with rqst.fetch() as resp:
data = await resp.json()
o = cls(name, owner_access_key) # type: ignore
if api_session.get().api_version[0] >= 5:
o.id = UUID(data['sessionId'])
Expand All @@ -331,6 +309,9 @@ async def get_or_create(
o.service_ports = data.get('servicePorts', [])
o.domain = domain_name
o.group = group_name
if 'background_task' in data:
task_id = data['background_task']
o.backgroundtask = resp.session.BackgroundTask(task_id)
return o

@api_function
Expand Down

0 comments on commit 171918b

Please sign in to comment.