diff --git a/changes/181.feature b/changes/181.feature new file mode 100644 index 00000000..7e3d96a3 --- /dev/null +++ b/changes/181.feature @@ -0,0 +1 @@ +Display kernel-pull-progress from background-task-reporter diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index a5f524cc..bfeead45 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -298,6 +298,27 @@ async def get_or_create( rqst.set_json(params) async with rqst.fetch() as resp: data = await resp.json() + with tqdm(total=100, unit='%') as pbar: + task_id = data['background_task'] + bgtask = resp.session.BackgroundTask(task_id) + async with bgtask.listen_events() as response: + async for ev in response: + progress = json.loads(ev.data) + if ev.event == 'bgtask_updated': + current = progress['current_progress'] + total = progress['total_progress'] + if total == 0: + total = 1e-2 + pbar.n = round(current / total * 100, 2) + pbar.update(0) + pbar.refresh() + elif ev.event == 'bgtask_done': + pbar.n = 100.0 + pbar.update(0) + pbar.refresh() + pbar.clear() + async with rqst.fetch() as resp: + data = await resp.json() o = cls(name, owner_access_key) # type: ignore if api_session.get().api_version[0] >= 5: o.id = UUID(data['sessionId'])