Skip to content

Commit

Permalink
[batch] Add job group in client and capability to list and get job groups
Browse files Browse the repository at this point in the history
  • Loading branch information
jigold committed Nov 16, 2023
1 parent e13e83a commit 6e0b358
Show file tree
Hide file tree
Showing 9 changed files with 578 additions and 125 deletions.
59 changes: 56 additions & 3 deletions batch/batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from hailtop.utils import humanize_timedelta_msecs, time_msecs_str

from .batch_format_version import BatchFormatVersion
from .constants import ROOT_JOB_GROUP_ID
from .exceptions import NonExistentBatchError, OpenBatchError
from .utils import coalesce

Expand Down Expand Up @@ -80,6 +79,60 @@ def _time_msecs_str(t):
return d


def job_group_record_to_dict(record: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a job group database record into a JSON-serializable dict.

    The reported state is derived in precedence order: any failed job makes
    the group a 'failure'; otherwise any cancellation makes it 'cancelled';
    otherwise a complete group is a 'success'; anything else is 'running'.

    NOTE: parses and replaces ``record['cost_breakdown']`` in place.
    """
    if record['n_failed'] > 0:
        state = 'failure'
    elif record['cancelled'] or record['n_cancelled'] > 0:
        state = 'cancelled'
    elif record['state'] == 'complete':
        # A complete group with no failures and no cancellations must have
        # succeeded on every job.
        assert record['n_succeeded'] == record['n_jobs']
        state = 'success'
    else:
        state = 'running'

    created_ms = record['time_created']
    completed_ms = record['time_completed']

    # Missing timestamps render as None rather than a formatted string.
    time_created = time_msecs_str(created_ms) if created_ms else None
    time_completed = time_msecs_str(completed_ms) if completed_ms else None

    duration_ms = None
    duration = None
    if created_ms and completed_ms:
        duration_ms = completed_ms - created_ms
        duration = humanize_timedelta_msecs(duration_ms)

    if record['cost_breakdown'] is not None:
        record['cost_breakdown'] = cost_breakdown_to_dict(json.loads(record['cost_breakdown']))

    result = {
        'batch_id': record['batch_id'],
        'job_group_id': record['job_group_id'],
        'state': state,
        'complete': record['state'] == 'complete',
        'n_jobs': record['n_jobs'],
        'n_completed': record['n_completed'],
        'n_succeeded': record['n_succeeded'],
        'n_failed': record['n_failed'],
        'n_cancelled': record['n_cancelled'],
        'time_created': time_created,
        'time_completed': time_completed,
        'duration_ms': duration_ms,
        'duration': duration,
        'cost': coalesce(record['cost'], 0),
        'cost_breakdown': record['cost_breakdown'],
    }

    # Only expose attributes when the group actually has some.
    attributes = json.loads(record['attributes'])
    if attributes:
        result['attributes'] = attributes

    return result


def job_record_to_dict(record: Dict[str, Any], name: Optional[str]) -> JobListEntryV1Alpha:
format_version = BatchFormatVersion(record['format_version'])

Expand Down Expand Up @@ -109,7 +162,7 @@ def job_record_to_dict(record: Dict[str, Any], name: Optional[str]) -> JobListEn
}


async def cancel_batch_in_db(db, batch_id):
async def cancel_job_group_in_db(db, batch_id, job_group_id):
@transaction(db)
async def cancel(tx):
record = await tx.execute_and_fetchone(
Expand All @@ -126,6 +179,6 @@ async def cancel(tx):
if record['state'] == 'open':
raise OpenBatchError(batch_id)

await tx.just_execute('CALL cancel_job_group(%s, %s);', (batch_id, ROOT_JOB_GROUP_ID))
await tx.just_execute('CALL cancel_job_group(%s, %s);', (batch_id, job_group_id))

await cancel()
4 changes: 2 additions & 2 deletions batch/batch/driver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
)
from web_common import render_template, set_message, setup_aiohttp_jinja2, setup_common_static_routes

from ..batch import cancel_batch_in_db
from ..batch import cancel_job_group_in_db
from ..batch_configuration import (
BATCH_STORAGE_URI,
CLOUD,
Expand Down Expand Up @@ -1227,7 +1227,7 @@ async def check(tx):

async def _cancel_batch(app, batch_id):
try:
await cancel_batch_in_db(app['db'], batch_id)
await cancel_job_group_in_db(app['db'], batch_id, ROOT_JOB_GROUP_ID)
except BatchUserError as exc:
log.info(f'cannot cancel batch because {exc.message}')
return
Expand Down
170 changes: 142 additions & 28 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)
from web_common import render_template, set_message, setup_aiohttp_jinja2, setup_common_static_routes

from ..batch import batch_record_to_dict, cancel_batch_in_db, job_record_to_dict
from ..batch import batch_record_to_dict, cancel_job_group_in_db, job_group_record_to_dict, job_record_to_dict
from ..batch_configuration import BATCH_STORAGE_URI, CLOUD, DEFAULT_NAMESPACE, SCOPE
from ..batch_format_version import BatchFormatVersion
from ..cloud.resource_utils import (
Expand Down Expand Up @@ -104,12 +104,17 @@
)
from .query import (
CURRENT_QUERY_VERSION,
parse_batch_jobs_query_v1,
parse_batch_jobs_query_v2,
parse_job_group_jobs_query_v1,
parse_list_batches_query_v1,
parse_list_batches_query_v2,
)
from .validate import ValidationError, validate_and_clean_jobs, validate_batch, validate_batch_update
from .validate import (
ValidationError,
validate_and_clean_jobs,
validate_batch,
validate_batch_update,
)

uvloop.install()

Expand Down Expand Up @@ -198,6 +203,13 @@ def cast_query_param_to_int(param: Optional[str]) -> Optional[int]:
return None


def cast_query_param_to_bool(param: Optional[str]) -> bool:
    """Parse an optional boolean query parameter.

    A missing parameter (``None``) is treated as ``False`` so callers may
    omit optional flags such as ``recursive``. Accepted spellings are
    'True'/'true'/'1' and 'False'/'false'/'0'.

    Raises:
        ValueError: if ``param`` is present but not a recognized boolean.
    """
    if param is None or param in ('False', 'false', '0'):
        return False
    if param in ('True', 'true', '1'):
        return True
    # Raise explicitly instead of using `assert`: asserts are stripped under
    # `python -O` and produce an unhelpful error for bad client input.
    raise ValueError(f'unknown boolean query parameter value: {param!r}')


@routes.get('/healthcheck')
async def get_healthcheck(_) -> web.Response:
return web.Response()
Expand Down Expand Up @@ -248,15 +260,21 @@ async def _handle_api_error(f: Callable[P, Awaitable[T]], *args: P.args, **kwarg
raise e.http_response()


async def _query_batch_jobs(
request: web.Request, batch_id: int, version: int, q: str, last_job_id: Optional[int]
async def _query_job_group_jobs(
request: web.Request,
batch_id: int,
job_group_id: int,
version: int,
q: str,
last_job_id: Optional[int],
recursive: bool,
) -> Tuple[List[JobListEntryV1Alpha], Optional[int]]:
db: Database = request.app['db']
if version == 1:
sql, sql_args = parse_batch_jobs_query_v1(batch_id, q, last_job_id)
sql, sql_args = parse_job_group_jobs_query_v1(batch_id, job_group_id, q, last_job_id, recursive)
else:
assert version == 2, version
sql, sql_args = parse_batch_jobs_query_v2(batch_id, q, last_job_id)
sql, sql_args = parse_batch_jobs_query_v2(batch_id, job_group_id, q, last_job_id, recursive)

jobs = [job_record_to_dict(record, record['name']) async for record in db.select_and_fetchall(sql, sql_args)]

Expand All @@ -269,7 +287,13 @@ async def _query_batch_jobs(


async def _get_jobs(
request: web.Request, batch_id: int, version: int, q: str, last_job_id: Optional[int]
request: web.Request,
batch_id: int,
job_group_id: int,
version: int,
q: str,
last_job_id: Optional[int],
recursive: bool,
) -> GetJobsResponseV1Alpha:
db = request.app['db']

Expand All @@ -283,7 +307,7 @@ async def _get_jobs(
if not record:
raise web.HTTPNotFound()

jobs, last_job_id = await _query_batch_jobs(request, batch_id, version, q, last_job_id)
jobs, last_job_id = await _query_job_group_jobs(request, batch_id, job_group_id, version, q, last_job_id, recursive)

if last_job_id is not None:
return {'jobs': jobs, 'last_job_id': last_job_id}
Expand All @@ -293,21 +317,38 @@ async def _get_jobs(
@routes.get('/api/v1alpha/batches/{batch_id}/jobs')
@billing_project_users_only()
@add_metadata_to_request
async def get_jobs_v1(request: web.Request, _, batch_id: int) -> web.Response:
q = request.query.get('q', '')
last_job_id = cast_query_param_to_int(request.query.get('last_job_id'))
resp = await _handle_api_error(_get_jobs, request, batch_id, 1, q, last_job_id)
assert resp is not None
return json_response(resp)
async def get_batch_jobs_v1(request: web.Request, _, batch_id: int) -> web.Response:
    """List a batch's jobs (v1 API): delegates to the root job group, query version 1."""
    return await _get_job_group_jobs(request, batch_id, ROOT_JOB_GROUP_ID, 1)


@routes.get('/api/v2alpha/batches/{batch_id}/jobs')
@billing_project_users_only()
@add_metadata_to_request
async def get_jobs_v2(request: web.Request, _, batch_id: int) -> web.Response:
async def get_batch_jobs_v2(request: web.Request, _, batch_id: int) -> web.Response:
    """List a batch's jobs (v2 API): delegates to the root job group, query version 2."""
    return await _get_job_group_jobs(request, batch_id, ROOT_JOB_GROUP_ID, 2)


@routes.get('/api/v1alpha/batches/{batch_id}/job-group/{job_group_id}/jobs')
@billing_project_users_only()
@add_metadata_to_request
async def get_job_group_jobs_v1(request: web.Request, _, batch_id: int) -> web.Response:
    """List jobs in a specific job group (v1 API, query version 1)."""
    # job_group_id comes from the URL path; int() raises for non-numeric ids.
    job_group_id = int(request.match_info['job_group_id'])
    return await _get_job_group_jobs(request, batch_id, job_group_id, 1)


@routes.get('/api/v2alpha/batches/{batch_id}/job-group/{job_group_id}/jobs')
@billing_project_users_only()
@add_metadata_to_request
async def get_job_group_jobs_v2(request: web.Request, _, batch_id: int) -> web.Response:
    """List jobs in a specific job group (v2 API, query version 2)."""
    # job_group_id comes from the URL path; int() raises for non-numeric ids.
    job_group_id = int(request.match_info['job_group_id'])
    return await _get_job_group_jobs(request, batch_id, job_group_id, 2)


async def _get_job_group_jobs(request, batch_id: int, job_group_id: int, version: int):
    """Shared body for the job-listing endpoints.

    Reads the `q`, `recursive`, and `last_job_id` query parameters, fetches
    the jobs via `_get_jobs`, and returns them as a JSON response. API errors
    are translated into HTTP error responses by `_handle_api_error`.
    """
    query_params = request.query
    last_job_id = cast_query_param_to_int(query_params.get('last_job_id'))
    recursive = cast_query_param_to_bool(query_params.get('recursive'))
    q = query_params.get('q', '')
    resp = await _handle_api_error(_get_jobs, request, batch_id, job_group_id, version, q, last_job_id, recursive)
    assert resp is not None
    return json_response(resp)

Expand Down Expand Up @@ -1180,6 +1221,18 @@ async def write_and_insert(tx):
return web.Response()


def root_job_group_spec(batch_spec: dict):
    """Build the spec for a batch's implicit root job group.

    The root group takes its attributes, cancel-after-n-failures setting,
    callback, and job count from the batch spec, and has no parent group.
    """
    spec: dict = {'job_group_id': ROOT_JOB_GROUP_ID}
    spec['attributes'] = batch_spec.get('attributes')
    spec['cancel_after_n_failures'] = batch_spec.get('cancel_after_n_failures')
    spec['callback'] = batch_spec.get('callback')
    spec['n_jobs'] = batch_spec['n_jobs']
    # The root group is top-level: no absolute or in-update parent.
    spec['absolute_parent_id'] = None
    spec['in_update_parent_id'] = None
    return spec


@routes.post('/api/v1alpha/batches/create-fast')
@auth.authenticated_users_only()
@add_metadata_to_request
Expand Down Expand Up @@ -1220,6 +1273,7 @@ async def create_batch(request, userdata):
)
else:
update_id = None

request['batch_telemetry']['batch_id'] = str(id)
return json_response({'id': id, 'update_id': update_id})

Expand Down Expand Up @@ -1402,8 +1456,11 @@ async def update_batch_fast(request, userdata):
if f'update {update_id} is already committed' == e.reason:
return json_response({'update_id': update_id, 'start_job_id': start_job_id})
raise

await _commit_update(app, batch_id, update_id, user, db)

request['batch_telemetry']['batch_id'] = str(batch_id)

return json_response({'update_id': update_id, 'start_job_id': start_job_id})


Expand Down Expand Up @@ -1539,8 +1596,47 @@ async def _get_batch(app, batch_id):
return batch_record_to_dict(record)


async def _cancel_batch(app, batch_id):
await cancel_batch_in_db(app['db'], batch_id)
async def _get_job_group(app, batch_id: int, job_group_id: int):
    """Fetch one job group and return it as a JSON-serializable dict.

    The query joins the job group with its per-state job counts, its
    cancellation status, and a lateral cost aggregation (total cost plus a
    per-resource cost breakdown).

    Raises:
        web.HTTPNotFound: if no matching, non-deleted job group exists.
    """
    db: Database = app['db']

    record = await db.select_and_fetchone(
        '''
SELECT job_groups.*,
  job_groups_cancelled.id IS NOT NULL AS cancelled,
  job_groups_n_jobs_in_complete_states.n_completed,
  job_groups_n_jobs_in_complete_states.n_succeeded,
  job_groups_n_jobs_in_complete_states.n_failed,
  job_groups_n_jobs_in_complete_states.n_cancelled,
  cost_t.*
FROM job_groups
LEFT JOIN batches ON batches.id = job_groups.batch_id
LEFT JOIN job_groups_n_jobs_in_complete_states
  ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id
LEFT JOIN job_groups_cancelled
  ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
LEFT JOIN LATERAL (
  SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown
  FROM (
    SELECT batch_id, job_group_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
    FROM aggregated_job_group_resources_v3
    WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id
    GROUP BY batch_id, job_group_id, resource_id
  ) AS usage_t
  LEFT JOIN resources ON usage_t.resource_id = resources.resource_id
  GROUP BY batch_id, job_group_id
) AS cost_t ON TRUE
WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted;
''',
        (batch_id, job_group_id),
    )
    if not record:
        raise web.HTTPNotFound()

    return job_group_record_to_dict(record)


async def _cancel_job_group(app, batch_id, job_group_id):
    """Cancel a job group in the database and signal the state-change event."""
    await cancel_job_group_in_db(app['db'], batch_id, job_group_id)
    # Wake any waiter on the app-level event so the cancellation is noticed
    # promptly (presumably the driver's cancel loop — confirm against driver).
    app['cancel_batch_state_changed'].set()
    return web.Response()

Expand Down Expand Up @@ -1576,7 +1672,24 @@ async def get_batch(request: web.Request, _, batch_id: int) -> web.Response:
@billing_project_users_only()
@add_metadata_to_request
async def cancel_batch(request: web.Request, _, batch_id: int) -> web.Response:
    """Cancel an entire batch by cancelling its root job group."""
    await _handle_api_error(_cancel_job_group, request.app, batch_id, ROOT_JOB_GROUP_ID)
    return web.Response()


@routes.get('/api/v1alpha/batches/{batch_id}/job-groups/{job_group_id}')
@billing_project_users_only()
@add_metadata_to_request
async def get_job_group(request: web.Request, _, batch_id: int) -> web.Response:
    """Return a single job group as JSON (404 if it does not exist)."""
    # job_group_id comes from the URL path; int() raises for non-numeric ids.
    job_group_id = int(request.match_info['job_group_id'])
    return json_response(await _get_job_group(request.app, batch_id, job_group_id))


@routes.patch('/api/v1alpha/batches/{batch_id}/job-groups/{job_group_id}/cancel')
@billing_project_users_only()
@add_metadata_to_request
async def cancel_job_group(request: web.Request, _, batch_id: int) -> web.Response:
    """Cancel a single job group, translating API errors into HTTP responses."""
    # job_group_id comes from the URL path; int() raises for non-numeric ids.
    job_group_id = int(request.match_info['job_group_id'])
    await _handle_api_error(_cancel_job_group, request.app, batch_id, job_group_id)
    return web.Response()


Expand Down Expand Up @@ -1631,13 +1744,12 @@ async def commit_update(request: web.Request, userdata):
record = await db.select_and_fetchone(
'''
SELECT start_job_id, job_groups_cancelled.id IS NOT NULL AS cancelled
FROM job_groups
LEFT JOIN batches ON job_groups.batch_id = batches.id
LEFT JOIN batch_updates ON job_groups.batch_id = batch_updates.batch_id
LEFT JOIN job_groups_cancelled ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
WHERE job_groups.user = %s AND job_groups.batch_id = %s AND job_groups.job_group_id = %s AND batch_updates.update_id = %s AND NOT deleted;
FROM batches
LEFT JOIN batch_updates ON batches.id = batch_updates.batch_id
LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id AND job_groups_cancelled.job_group_id = %s
WHERE batches.user = %s AND batches.id = %s AND batch_updates.update_id = %s AND NOT deleted;
''',
(user, batch_id, ROOT_JOB_GROUP_ID, update_id),
(ROOT_JOB_GROUP_ID, user, batch_id, update_id),
)
if not record:
raise web.HTTPNotFound()
Expand Down Expand Up @@ -1692,7 +1804,9 @@ async def ui_batch(request, userdata, batch_id):
last_job_id = cast_query_param_to_int(request.query.get('last_job_id'))

try:
jobs, last_job_id = await _query_batch_jobs(request, batch_id, CURRENT_QUERY_VERSION, q, last_job_id)
jobs, last_job_id = await _query_job_group_jobs(
request, batch_id, ROOT_JOB_GROUP_ID, CURRENT_QUERY_VERSION, q, last_job_id, recursive=True
)
except QueryError as e:
session = await aiohttp_session.get_session(request)
set_message(session, e.message, 'error')
Expand Down Expand Up @@ -1730,7 +1844,7 @@ async def ui_cancel_batch(request: web.Request, _, batch_id: int) -> NoReturn:
params['q'] = str(q)
session = await aiohttp_session.get_session(request)
try:
await _handle_ui_error(session, _cancel_batch, request.app, batch_id)
await _handle_ui_error(session, _cancel_job_group, request.app, batch_id, ROOT_JOB_GROUP_ID)
set_message(session, f'Batch {batch_id} cancelled.', 'info')
finally:
location = request.app.router['batches'].url_for().with_query(params)
Expand Down
Loading

0 comments on commit 6e0b358

Please sign in to comment.