Skip to content

Commit f73747b

Browse files
committed
Merge upstream HEAD(13de4e6, 2024-05-14) Add job groups [migration might take a while]
2 parents 4fe048f + 13de4e6 commit f73747b

File tree

80 files changed

+4505
-1377
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

80 files changed

+4505
-1377
lines changed

auth/auth/auth.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import kubernetes_asyncio.client
1212
import kubernetes_asyncio.client.rest
1313
import kubernetes_asyncio.config
14-
import uvloop
1514
from aiohttp import web
1615
from prometheus_async.aio.web import server_stats # type: ignore
1716

@@ -33,7 +32,7 @@
3332
from gear.auth import AIOHTTPHandler, get_session_id
3433
from gear.cloud_config import get_global_config
3534
from gear.profiling import install_profiler_if_requested
36-
from hailtop import httpx
35+
from hailtop import httpx, uvloopx
3736
from hailtop.auth import AzureFlow, Flow, GoogleFlow, IdentityProvider
3837
from hailtop.config import get_deploy_config
3938
from hailtop.hail_logging import AccessLogger
@@ -56,8 +55,6 @@
5655

5756
log = logging.getLogger('auth')
5857

59-
uvloop.install()
60-
6158
CLOUD = get_global_config()['cloud']
6259
DEFAULT_NAMESPACE = os.environ['HAIL_DEFAULT_NAMESPACE']
6360

@@ -842,6 +839,7 @@ async def on_startup(app):
842839
kubernetes_asyncio.config.load_incluster_config()
843840
app[AppKeys.K8S_CLIENT] = kubernetes_asyncio.client.CoreV1Api()
844841
exit_stack.push_async_callback(app[AppKeys.K8S_CLIENT].api_client.rest_client.pool_manager.close)
842+
845843
app[AppKeys.K8S_CACHE] = K8sCache(app[AppKeys.K8S_CLIENT])
846844

847845

@@ -886,6 +884,8 @@ async def auth_check_csrf_token(request: web.Request, handler: AIOHTTPHandler):
886884

887885

888886
def run():
887+
uvloopx.install()
888+
889889
install_profiler_if_requested('auth')
890890

891891
app = web.Application(middlewares=[auth_check_csrf_token, monitor_endpoints_middleware])

batch/batch/batch.py

+90-38
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
import json
22
import logging
3-
from typing import Any, Dict, List, Optional
3+
from typing import Any, Dict, List, Optional, cast
44

55
from gear import transaction
6-
from hailtop.batch_client.types import CostBreakdownEntry, JobListEntryV1Alpha
6+
from hailtop.batch_client.globals import ROOT_JOB_GROUP_ID
7+
from hailtop.batch_client.types import CostBreakdownEntry, GetJobGroupResponseV1Alpha, JobListEntryV1Alpha
78
from hailtop.utils import humanize_timedelta_msecs, time_msecs_str
89

910
from .batch_format_version import BatchFormatVersion
10-
from .exceptions import NonExistentBatchError, OpenBatchError
11+
from .exceptions import NonExistentJobGroupError
1112
from .utils import coalesce
1213

1314
log = logging.getLogger('batch')
1415

1516

17+
def _maybe_time_msecs_str(t: Optional[int]) -> Optional[str]:
    """Format ``t`` with ``time_msecs_str``; pass ``None`` through unchanged.

    Only ``None`` is treated as "missing"; every other value, including
    ``0``, is handed to ``time_msecs_str``.
    """
    return None if t is None else time_msecs_str(t)
21+
22+
1623
def cost_breakdown_to_dict(cost_breakdown: Dict[str, float]) -> List[CostBreakdownEntry]:
    """Turn a ``{resource: cost}`` mapping into a list of entry dicts.

    Each entry has the keys ``'resource'`` and ``'cost'``; entries follow the
    iteration order of ``cost_breakdown``.
    """
    return [dict(resource=name, cost=amount) for name, amount in cost_breakdown.items()]
1825

@@ -30,14 +37,9 @@ def batch_record_to_dict(record: Dict[str, Any]) -> Dict[str, Any]:
3037
else:
3138
state = 'running'
3239

33-
def _time_msecs_str(t):
34-
if t:
35-
return time_msecs_str(t)
36-
return None
37-
38-
time_created = _time_msecs_str(record['time_created'])
39-
time_closed = _time_msecs_str(record['time_closed'])
40-
time_completed = _time_msecs_str(record['time_completed'])
40+
time_created = _maybe_time_msecs_str(record['time_created'])
41+
time_closed = _maybe_time_msecs_str(record['time_closed'])
42+
time_completed = _maybe_time_msecs_str(record['time_completed'])
4143

4244
if record['time_created'] and record['time_completed']:
4345
duration_ms = record['time_completed'] - record['time_created']
@@ -50,7 +52,7 @@ def _time_msecs_str(t):
5052
if cost_breakdown is not None:
5153
cost_breakdown = cost_breakdown_to_dict(json.loads(cost_breakdown))
5254

53-
d = {
55+
batch_response = {
5456
'id': record['id'],
5557
'user': record['user'],
5658
'billing_project': record['billing_project'],
@@ -75,9 +77,55 @@ def _time_msecs_str(t):
7577

7678
attributes = json.loads(record['attributes'])
7779
if attributes:
78-
d['attributes'] = attributes
80+
batch_response['attributes'] = attributes
81+
82+
return batch_response
83+
84+
85+
def job_group_record_to_dict(record: Dict[str, Any]) -> GetJobGroupResponseV1Alpha:
    """Build the job-group response dict from a job-group record.

    The reported ``state`` is derived in priority order: any failed job gives
    ``'failure'``; a cancelled group or any cancelled job gives
    ``'cancelled'``; a record whose ``state`` is ``'complete'`` gives
    ``'success'``; anything else is ``'running'``.

    ``record`` is only read, never modified, so the same record can safely be
    converted more than once.

    Args:
        record: Mapping with the keys read below (``batch_id``,
            ``job_group_id``, ``state``, ``cancelled``, the ``n_*`` counters,
            ``time_created``, ``time_completed``, ``cost``,
            ``cost_breakdown`` and ``attributes``). ``cost_breakdown`` and
            ``attributes`` are JSON-encoded strings; ``cost_breakdown`` may be
            ``None``.

    Returns:
        The response dict. ``'attributes'`` is included only when the decoded
        attributes are non-empty.
    """
    if record['n_failed'] > 0:
        state = 'failure'
    elif record['cancelled'] or record['n_cancelled'] > 0:
        state = 'cancelled'
    elif record['state'] == 'complete':
        # Complete with no failures and no cancellations: every job succeeded.
        assert record['n_succeeded'] == record['n_jobs']
        state = 'success'
    else:
        state = 'running'

    time_created = _maybe_time_msecs_str(record['time_created'])
    time_completed = _maybe_time_msecs_str(record['time_completed'])

    if record['time_created'] and record['time_completed']:
        duration_ms = record['time_completed'] - record['time_created']
    else:
        duration_ms = None

    # Decode into a local instead of writing back into `record`: overwriting
    # the caller's JSON string with a list would make a second conversion of
    # the same record fail inside json.loads.
    cost_breakdown = record['cost_breakdown']
    if cost_breakdown is not None:
        cost_breakdown = cost_breakdown_to_dict(json.loads(cost_breakdown))

    job_group_response = {
        'batch_id': record['batch_id'],
        'job_group_id': record['job_group_id'],
        'state': state,
        'complete': record['state'] == 'complete',
        'n_jobs': record['n_jobs'],
        'n_completed': record['n_completed'],
        'n_succeeded': record['n_succeeded'],
        'n_failed': record['n_failed'],
        'n_cancelled': record['n_cancelled'],
        'time_created': time_created,
        'time_completed': time_completed,
        'duration': duration_ms,
        'cost': coalesce(record['cost'], 0),
        'cost_breakdown': cost_breakdown,
    }

    attributes = json.loads(record['attributes'])
    if attributes:
        job_group_response['attributes'] = attributes

    return cast(GetJobGroupResponseV1Alpha, job_group_response)
81129

82130

83131
def job_record_to_dict(record: Dict[str, Any], name: Optional[str]) -> JobListEntryV1Alpha:
@@ -95,38 +143,42 @@ def job_record_to_dict(record: Dict[str, Any], name: Optional[str]) -> JobListEn
95143
if cost_breakdown is not None:
96144
cost_breakdown = cost_breakdown_to_dict(json.loads(cost_breakdown))
97145

98-
return {
99-
'batch_id': record['batch_id'],
100-
'job_id': record['job_id'],
101-
'name': name,
102-
'user': record['user'],
103-
'billing_project': record['billing_project'],
104-
'state': record['state'],
105-
'exit_code': exit_code,
106-
'duration': duration,
107-
'cost': coalesce(record.get('cost'), 0),
108-
'msec_mcpu': record['msec_mcpu'],
109-
'cost_breakdown': cost_breakdown,
110-
}
111-
112-
113-
async def cancel_batch_in_db(db, batch_id):
146+
return cast(
147+
JobListEntryV1Alpha,
148+
{
149+
'batch_id': record['batch_id'],
150+
'job_id': record['job_id'],
151+
'name': name,
152+
'user': record['user'],
153+
'billing_project': record['billing_project'],
154+
'state': record['state'],
155+
'exit_code': exit_code,
156+
'duration': duration,
157+
'cost': coalesce(record['cost'], 0),
158+
'msec_mcpu': record['msec_mcpu'],
159+
'cost_breakdown': cost_breakdown,
160+
},
161+
)
162+
163+
164+
async def cancel_job_group_in_db(db, batch_id, job_group_id):
114165
@transaction(db)
115166
async def cancel(tx):
116167
record = await tx.execute_and_fetchone(
117168
"""
118-
SELECT `state` FROM batches
119-
WHERE id = %s AND NOT deleted
169+
SELECT 1
170+
FROM job_groups
171+
LEFT JOIN batches ON batches.id = job_groups.batch_id
172+
LEFT JOIN batch_updates ON job_groups.batch_id = batch_updates.batch_id AND
173+
job_groups.update_id = batch_updates.update_id
174+
WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted AND (batch_updates.committed OR job_groups.job_group_id = %s)
120175
FOR UPDATE;
121176
""",
122-
(batch_id,),
177+
(batch_id, job_group_id, ROOT_JOB_GROUP_ID),
123178
)
124179
if not record:
125-
raise NonExistentBatchError(batch_id)
126-
127-
if record['state'] == 'open':
128-
raise OpenBatchError(batch_id)
180+
raise NonExistentJobGroupError(batch_id, job_group_id)
129181

130-
await tx.just_execute('CALL cancel_batch(%s);', (batch_id,))
182+
await tx.just_execute('CALL cancel_job_group(%s, %s);', (batch_id, job_group_id))
131183

132184
await cancel()

batch/batch/constants.py

-1
This file was deleted.

0 commit comments

Comments
 (0)