
Commit 13de4e6

[batch] Add Job Groups to Batch (#14282)
This PR adds the job groups functionality described in this [RFC](hail-is/hail-rfcs#5) to the Batch backend and `hailtop.batch_client`, including support for nested job groups up to a maximum depth of 5. Note that none of these changes are user-facing yet (hence no change log entry here).

The PRs that came before this one:
- #13475
- #13487
- #13810 (note that this database migration required a shutdown)

Subsequent PRs will need to implement the following:
- Querying job groups with the flexible query language (v2)
- Implementing job groups in the Scala Client for QoB
- Using job groups in QoB with `cancel_after_n_failures=1` for all new stages of worker jobs
- UI functionality to page and sort through job groups
- A new `hailtop.batch` interface for users to define and work with job groups

A couple of nuances in the implementation came up that I also tried to articulate in the RFC:

1. The root job group, with ID = 0, does not belong to an update (`update_id` IS NULL). This means that any check that looks for "committed" job groups needs to use `(batch_updates.committed OR job_groups.job_group_id = %s)`, where `%s` is the `ROOT_JOB_GROUP_ID`.
2. When a job group is cancelled, only that specific job group is inserted into `job_groups_cancelled`. This table does **NOT** contain the transitive job groups that were also cancelled indirectly. The reason is that we cannot guarantee a user won't have millions of job groups, and we can't insert millions of records inside a single SQL stored procedure. Any query on the driver / front_end must now walk up the tree and check whether any ancestor has been cancelled; such queries look similar to the code below [1].
3. There used to be `DELETE FROM` statements in `commit_batch_update` and `commit_batch` that cleaned up old records no longer used in `job_group_inst_coll_cancellable_resources` and `job_groups_inst_coll_staging`. This cleanup now occurs in a periodic loop on the driver.
4. The `job_group_inst_coll_cancellable_resources` and `job_groups_inst_coll_staging` tables have values that represent the sum over all child job groups. For example, if a job group has 1 job and its child job group has 2 jobs, the staging table has `n_jobs = 3` for the parent job group and `n_jobs = 2` for the child. Likewise, all of the billing triggers and MJC have to use the `job_group_self_and_ancestors` table to modify the job group the job belongs to as well as its parent job groups.

[1] Code to check whether a job group has been cancelled:

```mysql
SELECT job_groups.*, cancelled_t.cancelled IS NOT NULL AS cancelled
FROM job_groups
LEFT JOIN LATERAL (
  SELECT 1 AS cancelled
  FROM job_group_self_and_ancestors
  INNER JOIN job_groups_cancelled
          ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND
             job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id
  WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND
        job_groups.job_group_id = job_group_self_and_ancestors.job_group_id
) AS cancelled_t ON TRUE
WHERE ...
```
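The rollup in point 4 can be illustrated with a small, hypothetical in-memory Python sketch (the real bookkeeping lives in SQL triggers and stored procedures; `self_and_ancestors` and `add_jobs` are invented names mirroring the `job_group_self_and_ancestors` closure table):

```python
from collections import defaultdict


def self_and_ancestors(parents, job_group_id):
    """Yield the job group and each of its ancestors up to the root (ID 0)."""
    current = job_group_id
    yield current
    while current in parents:
        current = parents[current]
        yield current


def add_jobs(staging, parents, job_group_id, n_jobs):
    # staging counts for a job group include all of its descendants,
    # so adding jobs increments the group itself and every ancestor
    for group in self_and_ancestors(parents, job_group_id):
        staging[group] += n_jobs


parents = {1: 0, 2: 1}            # group 2 nested under group 1, under root 0
staging = defaultdict(int)
add_jobs(staging, parents, 1, 1)  # 1 job directly in group 1
add_jobs(staging, parents, 2, 2)  # 2 jobs in child group 2
assert staging[1] == 3 and staging[2] == 2  # matches the example above
```

This is why cancelling a parent can account for every descendant's resources without enumerating the descendants at cancel time.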
1 parent ad60919 commit 13de4e6

27 files changed: +3664 / -763 lines

batch/batch/batch.py (+90 -38)
```diff
@@ -1,18 +1,25 @@
 import json
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast

 from gear import transaction
-from hailtop.batch_client.types import CostBreakdownEntry, JobListEntryV1Alpha
+from hailtop.batch_client.globals import ROOT_JOB_GROUP_ID
+from hailtop.batch_client.types import CostBreakdownEntry, GetJobGroupResponseV1Alpha, JobListEntryV1Alpha
 from hailtop.utils import humanize_timedelta_msecs, time_msecs_str

 from .batch_format_version import BatchFormatVersion
-from .exceptions import NonExistentBatchError, OpenBatchError
+from .exceptions import NonExistentJobGroupError
 from .utils import coalesce

 log = logging.getLogger('batch')


+def _maybe_time_msecs_str(t: Optional[int]) -> Optional[str]:
+    if t is not None:
+        return time_msecs_str(t)
+    return None
+
+
 def cost_breakdown_to_dict(cost_breakdown: Dict[str, float]) -> List[CostBreakdownEntry]:
     return [{'resource': resource, 'cost': cost} for resource, cost in cost_breakdown.items()]

@@ -30,14 +37,9 @@ def batch_record_to_dict(record: Dict[str, Any]) -> Dict[str, Any]:
     else:
         state = 'running'

-    def _time_msecs_str(t):
-        if t:
-            return time_msecs_str(t)
-        return None
-
-    time_created = _time_msecs_str(record['time_created'])
-    time_closed = _time_msecs_str(record['time_closed'])
-    time_completed = _time_msecs_str(record['time_completed'])
+    time_created = _maybe_time_msecs_str(record['time_created'])
+    time_closed = _maybe_time_msecs_str(record['time_closed'])
+    time_completed = _maybe_time_msecs_str(record['time_completed'])

     if record['time_created'] and record['time_completed']:
         duration_ms = record['time_completed'] - record['time_created']
@@ -49,7 +51,7 @@ def _time_msecs_str(t):
     if record['cost_breakdown'] is not None:
         record['cost_breakdown'] = cost_breakdown_to_dict(json.loads(record['cost_breakdown']))

-    d = {
+    batch_response = {
         'id': record['id'],
         'user': record['user'],
         'billing_project': record['billing_project'],
@@ -74,9 +76,55 @@ def _time_msecs_str(t):

     attributes = json.loads(record['attributes'])
     if attributes:
-        d['attributes'] = attributes
+        batch_response['attributes'] = attributes
+
+    return batch_response
+
+
+def job_group_record_to_dict(record: Dict[str, Any]) -> GetJobGroupResponseV1Alpha:
+    if record['n_failed'] > 0:
+        state = 'failure'
+    elif record['cancelled'] or record['n_cancelled'] > 0:
+        state = 'cancelled'
+    elif record['state'] == 'complete':
+        assert record['n_succeeded'] == record['n_jobs']
+        state = 'success'
+    else:
+        state = 'running'

-    return d
+    time_created = _maybe_time_msecs_str(record['time_created'])
+    time_completed = _maybe_time_msecs_str(record['time_completed'])
+
+    if record['time_created'] and record['time_completed']:
+        duration_ms = record['time_completed'] - record['time_created']
+    else:
+        duration_ms = None
+
+    if record['cost_breakdown'] is not None:
+        record['cost_breakdown'] = cost_breakdown_to_dict(json.loads(record['cost_breakdown']))
+
+    job_group_response = {
+        'batch_id': record['batch_id'],
+        'job_group_id': record['job_group_id'],
+        'state': state,
+        'complete': record['state'] == 'complete',
+        'n_jobs': record['n_jobs'],
+        'n_completed': record['n_completed'],
+        'n_succeeded': record['n_succeeded'],
+        'n_failed': record['n_failed'],
+        'n_cancelled': record['n_cancelled'],
+        'time_created': time_created,
+        'time_completed': time_completed,
+        'duration': duration_ms,
+        'cost': coalesce(record['cost'], 0),
+        'cost_breakdown': record['cost_breakdown'],
+    }
+
+    attributes = json.loads(record['attributes'])
+    if attributes:
+        job_group_response['attributes'] = attributes
+
+    return cast(GetJobGroupResponseV1Alpha, job_group_response)


 def job_record_to_dict(record: Dict[str, Any], name: Optional[str]) -> JobListEntryV1Alpha:
@@ -93,38 +141,42 @@ def job_record_to_dict(record: Dict[str, Any], name: Optional[str]) -> JobListEn
     if record['cost_breakdown'] is not None:
         record['cost_breakdown'] = cost_breakdown_to_dict(json.loads(record['cost_breakdown']))

-    return {
-        'batch_id': record['batch_id'],
-        'job_id': record['job_id'],
-        'name': name,
-        'user': record['user'],
-        'billing_project': record['billing_project'],
-        'state': record['state'],
-        'exit_code': exit_code,
-        'duration': duration,
-        'cost': coalesce(record['cost'], 0),
-        'msec_mcpu': record['msec_mcpu'],
-        'cost_breakdown': record['cost_breakdown'],
-    }
-
-
-async def cancel_batch_in_db(db, batch_id):
+    return cast(
+        JobListEntryV1Alpha,
+        {
+            'batch_id': record['batch_id'],
+            'job_id': record['job_id'],
+            'name': name,
+            'user': record['user'],
+            'billing_project': record['billing_project'],
+            'state': record['state'],
+            'exit_code': exit_code,
+            'duration': duration,
+            'cost': coalesce(record['cost'], 0),
+            'msec_mcpu': record['msec_mcpu'],
+            'cost_breakdown': record['cost_breakdown'],
+        },
+    )
+
+
+async def cancel_job_group_in_db(db, batch_id, job_group_id):
     @transaction(db)
     async def cancel(tx):
         record = await tx.execute_and_fetchone(
             """
-            SELECT `state` FROM batches
-            WHERE id = %s AND NOT deleted
+            SELECT 1
+            FROM job_groups
+            LEFT JOIN batches ON batches.id = job_groups.batch_id
+            LEFT JOIN batch_updates ON job_groups.batch_id = batch_updates.batch_id AND
+                job_groups.update_id = batch_updates.update_id
+            WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted AND (batch_updates.committed OR job_groups.job_group_id = %s)
             FOR UPDATE;
             """,
-            (batch_id,),
+            (batch_id, job_group_id, ROOT_JOB_GROUP_ID),
         )
         if not record:
-            raise NonExistentBatchError(batch_id)
-
-        if record['state'] == 'open':
-            raise OpenBatchError(batch_id)
+            raise NonExistentJobGroupError(batch_id, job_group_id)

-        await tx.just_execute('CALL cancel_batch(%s);', (batch_id,))
+        await tx.just_execute('CALL cancel_job_group(%s, %s);', (batch_id, job_group_id))

     await cancel()
```
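For quick experimentation, the state derivation in `job_group_record_to_dict` can be restated as a dependency-free pure function (a sketch mirroring the logic above, not the production code):

```python
def job_group_state(record):
    # failure wins over cancellation; a complete group with no failures
    # or cancellations must have succeeded entirely
    if record['n_failed'] > 0:
        return 'failure'
    if record['cancelled'] or record['n_cancelled'] > 0:
        return 'cancelled'
    if record['state'] == 'complete':
        assert record['n_succeeded'] == record['n_jobs']
        return 'success'
    return 'running'


base = {'n_failed': 0, 'cancelled': False, 'n_cancelled': 0,
        'state': 'running', 'n_succeeded': 0, 'n_jobs': 2}
assert job_group_state(base) == 'running'
assert job_group_state({**base, 'n_failed': 1}) == 'failure'
assert job_group_state({**base, 'cancelled': True}) == 'cancelled'
assert job_group_state({**base, 'state': 'complete', 'n_succeeded': 2}) == 'success'
```

Note that `record['cancelled']` here is the ancestor-aware flag computed by the LATERAL lookup, not a column on `job_groups` itself.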

batch/batch/constants.py (-1): this file was deleted.

batch/batch/driver/canceller.py (+70 -37)
```diff
@@ -94,39 +94,44 @@ async def cancel_cancelled_ready_jobs_loop_body(self):
         }

         async def user_cancelled_ready_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
-            async for batch in self.db.select_and_fetchall(
+            async for job_group in self.db.select_and_fetchall(
                 """
-                SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled
-                FROM batches
-                LEFT JOIN job_groups_cancelled
-                       ON batches.id = job_groups_cancelled.id
+                SELECT job_groups.batch_id, job_groups.job_group_id, t.cancelled IS NOT NULL AS cancelled
+                FROM job_groups
+                LEFT JOIN LATERAL (
+                    SELECT 1 AS cancelled
+                    FROM job_group_self_and_ancestors
+                    INNER JOIN job_groups_cancelled
+                            ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND
+                               job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id
+                    WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND
+                          job_groups.job_group_id = job_group_self_and_ancestors.job_group_id
+                ) AS t ON TRUE
                 WHERE user = %s AND `state` = 'running';
                 """,
                 (user,),
             ):
-                if batch['cancelled']:
+                if job_group['cancelled']:
                     async for record in self.db.select_and_fetchall(
                         """
-                        SELECT jobs.job_id
+                        SELECT jobs.batch_id, jobs.job_id, jobs.job_group_id
                         FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
-                        WHERE batch_id = %s AND state = 'Ready' AND always_run = 0
+                        WHERE batch_id = %s AND job_group_id = %s AND state = 'Ready' AND always_run = 0
                         LIMIT %s;
                         """,
-                        (batch['id'], remaining.value),
+                        (job_group['batch_id'], job_group['job_group_id'], remaining.value),
                     ):
-                        record['batch_id'] = batch['id']
                         yield record
                 else:
                     async for record in self.db.select_and_fetchall(
                         """
-                        SELECT jobs.job_id
+                        SELECT jobs.batch_id, jobs.job_id, jobs.job_group_id
                         FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
-                        WHERE batch_id = %s AND state = 'Ready' AND always_run = 0 AND cancelled = 1
+                        WHERE batch_id = %s AND job_group_id = %s AND state = 'Ready' AND always_run = 0 AND cancelled = 1
                         LIMIT %s;
                         """,
-                        (batch['id'], remaining.value),
+                        (job_group['batch_id'], job_group['job_group_id'], remaining.value),
                     ):
-                        record['batch_id'] = batch['id']
                         yield record

         waitable_pool = WaitableSharedPool(self.async_worker_pool)
@@ -137,18 +142,30 @@ async def user_cancelled_ready_jobs(user, remaining) -> AsyncIterator[Dict[str,
         async for record in user_cancelled_ready_jobs(user, remaining):
             batch_id = record['batch_id']
             job_id = record['job_id']
+            job_group_id = record['job_group_id']
             id = (batch_id, job_id)
             log.info(f'cancelling job {id}')

-            async def cancel_with_error_handling(app, batch_id, job_id, id):
+            async def cancel_with_error_handling(app, batch_id, job_id, job_group_id, id):
                 try:
                     await mark_job_complete(
-                        app, batch_id, job_id, None, None, 'Cancelled', None, None, None, 'cancelled', []
+                        app,
+                        batch_id,
+                        job_id,
+                        None,
+                        job_group_id,
+                        None,
+                        'Cancelled',
+                        None,
+                        None,
+                        None,
+                        'cancelled',
+                        [],
                     )
                 except Exception:
                     log.info(f'error while cancelling job {id}', exc_info=True)

-            await waitable_pool.call(cancel_with_error_handling, self.app, batch_id, job_id, id)
+            await waitable_pool.call(cancel_with_error_handling, self.app, batch_id, job_id, job_group_id, id)

             remaining.value -= 1
             if remaining.value <= 0:
@@ -182,28 +199,34 @@ async def cancel_cancelled_creating_jobs_loop_body(self):
         }

         async def user_cancelled_creating_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
-            async for batch in self.db.select_and_fetchall(
+            async for job_group in self.db.select_and_fetchall(
                 """
-                SELECT batches.id
-                FROM batches
-                INNER JOIN job_groups_cancelled
-                        ON batches.id = job_groups_cancelled.id
+                SELECT job_groups.batch_id, job_groups.job_group_id
+                FROM job_groups
+                INNER JOIN LATERAL (
+                    SELECT 1 AS cancelled
+                    FROM job_group_self_and_ancestors
+                    INNER JOIN job_groups_cancelled
+                            ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND
+                               job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id
+                    WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND
+                          job_groups.job_group_id = job_group_self_and_ancestors.job_group_id
+                ) AS t ON TRUE
                 WHERE user = %s AND `state` = 'running';
                 """,
                 (user,),
             ):
                 async for record in self.db.select_and_fetchall(
                     """
-                    SELECT jobs.job_id, attempts.attempt_id, attempts.instance_name
+                    SELECT jobs.batch_id, jobs.job_id, attempts.attempt_id, attempts.instance_name, jobs.job_group_id
                     FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
                     STRAIGHT_JOIN attempts
                            ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id
-                    WHERE jobs.batch_id = %s AND state = 'Creating' AND always_run = 0 AND cancelled = 0
+                    WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND state = 'Creating' AND always_run = 0 AND cancelled = 0
                     LIMIT %s;
                     """,
-                    (batch['id'], remaining.value),
+                    (job_group['batch_id'], job_group['job_group_id'], remaining.value),
                 ):
-                    record['batch_id'] = batch['id']
                     yield record

         waitable_pool = WaitableSharedPool(self.async_worker_pool)
@@ -215,17 +238,21 @@ async def user_cancelled_creating_jobs(user, remaining) -> AsyncIterator[Dict[st
             batch_id = record['batch_id']
             job_id = record['job_id']
             attempt_id = record['attempt_id']
+            job_group_id = record['job_group_id']
             instance_name = record['instance_name']
             id = (batch_id, job_id)

-            async def cancel_with_error_handling(app, batch_id, job_id, attempt_id, instance_name, id):
+            async def cancel_with_error_handling(
+                app, batch_id, job_id, attempt_id, job_group_id, instance_name, id
+            ):
                 try:
                     end_time = time_msecs()
                     await mark_job_complete(
                         app,
                         batch_id,
                         job_id,
                         attempt_id,
+                        job_group_id,
                         instance_name,
                         'Cancelled',
                         None,
@@ -246,7 +273,7 @@ async def cancel_with_error_handling(app, batch_id, job_id, attempt_id, instance
                     log.info(f'cancelling creating job {id} on instance {instance_name}', exc_info=True)

             await waitable_pool.call(
-                cancel_with_error_handling, self.app, batch_id, job_id, attempt_id, instance_name, id
+                cancel_with_error_handling, self.app, batch_id, job_id, attempt_id, job_group_id, instance_name, id
             )

             remaining.value -= 1
@@ -279,28 +306,34 @@ async def cancel_cancelled_running_jobs_loop_body(self):
         }

         async def user_cancelled_running_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
-            async for batch in self.db.select_and_fetchall(
+            async for job_group in self.db.select_and_fetchall(
                 """
-                SELECT batches.id
-                FROM batches
-                INNER JOIN job_groups_cancelled
-                        ON batches.id = job_groups_cancelled.id
+                SELECT job_groups.batch_id, job_groups.job_group_id
+                FROM job_groups
+                INNER JOIN LATERAL (
+                    SELECT 1 AS cancelled
+                    FROM job_group_self_and_ancestors
+                    INNER JOIN job_groups_cancelled
+                            ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND
+                               job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id
+                    WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND
+                          job_groups.job_group_id = job_group_self_and_ancestors.job_group_id
+                ) AS t ON TRUE
                 WHERE user = %s AND `state` = 'running';
                 """,
                 (user,),
             ):
                 async for record in self.db.select_and_fetchall(
                     """
-                    SELECT jobs.job_id, attempts.attempt_id, attempts.instance_name
+                    SELECT jobs.batch_id, jobs.job_id, attempts.attempt_id, attempts.instance_name
                     FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
                     STRAIGHT_JOIN attempts
                            ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id
-                    WHERE jobs.batch_id = %s AND state = 'Running' AND always_run = 0 AND cancelled = 0
+                    WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND state = 'Running' AND always_run = 0 AND cancelled = 0
                     LIMIT %s;
                     """,
-                    (batch['id'], remaining.value),
+                    (job_group['batch_id'], job_group['job_group_id'], remaining.value),
                 ):
-                    record['batch_id'] = batch['id']
                     yield record

         waitable_pool = WaitableSharedPool(self.async_worker_pool)
```
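The ancestor-based cancellation lookup that these queries share can be exercised end to end with an in-memory sqlite sketch (sqlite has no LATERAL, so a correlated EXISTS stands in; the schema is trimmed and the rows are made up):

```python
import sqlite3

db = sqlite3.connect(':memory:')
db.executescript("""
CREATE TABLE job_groups (batch_id INT, job_group_id INT);
CREATE TABLE job_group_self_and_ancestors (batch_id INT, job_group_id INT, ancestor_id INT);
CREATE TABLE job_groups_cancelled (id INT, job_group_id INT);  -- id is the batch id

INSERT INTO job_groups VALUES (7, 0), (7, 1), (7, 2);
-- closure table: every group paired with itself and each of its ancestors
INSERT INTO job_group_self_and_ancestors VALUES
  (7, 0, 0), (7, 1, 1), (7, 1, 0), (7, 2, 2), (7, 2, 1), (7, 2, 0);
-- only group 1 was cancelled directly; its child (group 2) is cancelled transitively
INSERT INTO job_groups_cancelled VALUES (7, 1);
""")

rows = db.execute("""
SELECT job_groups.job_group_id,
       EXISTS (
         SELECT 1
         FROM job_group_self_and_ancestors a
         INNER JOIN job_groups_cancelled c
           ON a.batch_id = c.id AND a.ancestor_id = c.job_group_id
         WHERE a.batch_id = job_groups.batch_id
           AND a.job_group_id = job_groups.job_group_id
       ) AS cancelled
FROM job_groups
ORDER BY job_groups.job_group_id;
""").fetchall()

# group 2 reads as cancelled even though it never appears in job_groups_cancelled
assert rows == [(0, 0), (1, 1), (2, 1)]
```

This mirrors the design decision in the commit message: `job_groups_cancelled` stays small, and the closure table pays the cost at read time instead of write time.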
