1
1
import json
2
2
import logging
3
- from typing import Any , Dict , List , Optional
3
+ from typing import Any , Dict , List , Optional , cast
4
4
5
5
from gear import transaction
6
- from hailtop .batch_client .types import CostBreakdownEntry , JobListEntryV1Alpha
6
+ from hailtop .batch_client .globals import ROOT_JOB_GROUP_ID
7
+ from hailtop .batch_client .types import CostBreakdownEntry , GetJobGroupResponseV1Alpha , JobListEntryV1Alpha
7
8
from hailtop .utils import humanize_timedelta_msecs , time_msecs_str
8
9
9
10
from .batch_format_version import BatchFormatVersion
10
- from .exceptions import NonExistentBatchError , OpenBatchError
11
+ from .exceptions import NonExistentJobGroupError
11
12
from .utils import coalesce
12
13
13
14
log = logging .getLogger ('batch' )
14
15
15
16
17
+ def _maybe_time_msecs_str (t : Optional [int ]) -> Optional [str ]:
18
+ if t is not None :
19
+ return time_msecs_str (t )
20
+ return None
21
+
22
+
16
23
def cost_breakdown_to_dict (cost_breakdown : Dict [str , float ]) -> List [CostBreakdownEntry ]:
17
24
return [{'resource' : resource , 'cost' : cost } for resource , cost in cost_breakdown .items ()]
18
25
@@ -30,14 +37,9 @@ def batch_record_to_dict(record: Dict[str, Any]) -> Dict[str, Any]:
30
37
else :
31
38
state = 'running'
32
39
33
- def _time_msecs_str (t ):
34
- if t :
35
- return time_msecs_str (t )
36
- return None
37
-
38
- time_created = _time_msecs_str (record ['time_created' ])
39
- time_closed = _time_msecs_str (record ['time_closed' ])
40
- time_completed = _time_msecs_str (record ['time_completed' ])
40
+ time_created = _maybe_time_msecs_str (record ['time_created' ])
41
+ time_closed = _maybe_time_msecs_str (record ['time_closed' ])
42
+ time_completed = _maybe_time_msecs_str (record ['time_completed' ])
41
43
42
44
if record ['time_created' ] and record ['time_completed' ]:
43
45
duration_ms = record ['time_completed' ] - record ['time_created' ]
@@ -50,7 +52,7 @@ def _time_msecs_str(t):
50
52
if cost_breakdown is not None :
51
53
cost_breakdown = cost_breakdown_to_dict (json .loads (cost_breakdown ))
52
54
53
- d = {
55
+ batch_response = {
54
56
'id' : record ['id' ],
55
57
'user' : record ['user' ],
56
58
'billing_project' : record ['billing_project' ],
@@ -75,9 +77,55 @@ def _time_msecs_str(t):
75
77
76
78
attributes = json .loads (record ['attributes' ])
77
79
if attributes :
78
- d ['attributes' ] = attributes
80
+ batch_response ['attributes' ] = attributes
81
+
82
+ return batch_response
83
+
84
+
85
+ def job_group_record_to_dict (record : Dict [str , Any ]) -> GetJobGroupResponseV1Alpha :
86
+ if record ['n_failed' ] > 0 :
87
+ state = 'failure'
88
+ elif record ['cancelled' ] or record ['n_cancelled' ] > 0 :
89
+ state = 'cancelled'
90
+ elif record ['state' ] == 'complete' :
91
+ assert record ['n_succeeded' ] == record ['n_jobs' ]
92
+ state = 'success'
93
+ else :
94
+ state = 'running'
95
+
96
+ time_created = _maybe_time_msecs_str (record ['time_created' ])
97
+ time_completed = _maybe_time_msecs_str (record ['time_completed' ])
79
98
80
- return d
99
+ if record ['time_created' ] and record ['time_completed' ]:
100
+ duration_ms = record ['time_completed' ] - record ['time_created' ]
101
+ else :
102
+ duration_ms = None
103
+
104
+ if record ['cost_breakdown' ] is not None :
105
+ record ['cost_breakdown' ] = cost_breakdown_to_dict (json .loads (record ['cost_breakdown' ]))
106
+
107
+ job_group_response = {
108
+ 'batch_id' : record ['batch_id' ],
109
+ 'job_group_id' : record ['job_group_id' ],
110
+ 'state' : state ,
111
+ 'complete' : record ['state' ] == 'complete' ,
112
+ 'n_jobs' : record ['n_jobs' ],
113
+ 'n_completed' : record ['n_completed' ],
114
+ 'n_succeeded' : record ['n_succeeded' ],
115
+ 'n_failed' : record ['n_failed' ],
116
+ 'n_cancelled' : record ['n_cancelled' ],
117
+ 'time_created' : time_created ,
118
+ 'time_completed' : time_completed ,
119
+ 'duration' : duration_ms ,
120
+ 'cost' : coalesce (record ['cost' ], 0 ),
121
+ 'cost_breakdown' : record ['cost_breakdown' ],
122
+ }
123
+
124
+ attributes = json .loads (record ['attributes' ])
125
+ if attributes :
126
+ job_group_response ['attributes' ] = attributes
127
+
128
+ return cast (GetJobGroupResponseV1Alpha , job_group_response )
81
129
82
130
83
131
def job_record_to_dict (record : Dict [str , Any ], name : Optional [str ]) -> JobListEntryV1Alpha :
@@ -95,38 +143,42 @@ def job_record_to_dict(record: Dict[str, Any], name: Optional[str]) -> JobListEn
95
143
if cost_breakdown is not None :
96
144
cost_breakdown = cost_breakdown_to_dict (json .loads (cost_breakdown ))
97
145
98
- return {
99
- 'batch_id' : record ['batch_id' ],
100
- 'job_id' : record ['job_id' ],
101
- 'name' : name ,
102
- 'user' : record ['user' ],
103
- 'billing_project' : record ['billing_project' ],
104
- 'state' : record ['state' ],
105
- 'exit_code' : exit_code ,
106
- 'duration' : duration ,
107
- 'cost' : coalesce (record .get ('cost' ), 0 ),
108
- 'msec_mcpu' : record ['msec_mcpu' ],
109
- 'cost_breakdown' : cost_breakdown ,
110
- }
111
-
112
-
113
- async def cancel_batch_in_db (db , batch_id ):
146
+ return cast (
147
+ JobListEntryV1Alpha ,
148
+ {
149
+ 'batch_id' : record ['batch_id' ],
150
+ 'job_id' : record ['job_id' ],
151
+ 'name' : name ,
152
+ 'user' : record ['user' ],
153
+ 'billing_project' : record ['billing_project' ],
154
+ 'state' : record ['state' ],
155
+ 'exit_code' : exit_code ,
156
+ 'duration' : duration ,
157
+ 'cost' : coalesce (record ['cost' ], 0 ),
158
+ 'msec_mcpu' : record ['msec_mcpu' ],
159
+ 'cost_breakdown' : cost_breakdown ,
160
+ },
161
+ )
162
+
163
+
164
+ async def cancel_job_group_in_db (db , batch_id , job_group_id ):
114
165
@transaction (db )
115
166
async def cancel (tx ):
116
167
record = await tx .execute_and_fetchone (
117
168
"""
118
- SELECT `state` FROM batches
119
- WHERE id = %s AND NOT deleted
169
+ SELECT 1
170
+ FROM job_groups
171
+ LEFT JOIN batches ON batches.id = job_groups.batch_id
172
+ LEFT JOIN batch_updates ON job_groups.batch_id = batch_updates.batch_id AND
173
+ job_groups.update_id = batch_updates.update_id
174
+ WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted AND (batch_updates.committed OR job_groups.job_group_id = %s)
120
175
FOR UPDATE;
121
176
""" ,
122
- (batch_id ,),
177
+ (batch_id , job_group_id , ROOT_JOB_GROUP_ID ),
123
178
)
124
179
if not record :
125
- raise NonExistentBatchError (batch_id )
126
-
127
- if record ['state' ] == 'open' :
128
- raise OpenBatchError (batch_id )
180
+ raise NonExistentJobGroupError (batch_id , job_group_id )
129
181
130
- await tx .just_execute ('CALL cancel_batch (%s);' , (batch_id ,))
182
+ await tx .just_execute ('CALL cancel_job_group (%s, %s );' , (batch_id , job_group_id ))
131
183
132
184
await cancel ()
0 commit comments