diff --git a/batch/batch/driver/canceller.py b/batch/batch/driver/canceller.py
index 4ee7f0e51c18..f6caf8cf6734 100644
--- a/batch/batch/driver/canceller.py
+++ b/batch/batch/driver/canceller.py
@@ -94,39 +94,40 @@ async def cancel_cancelled_ready_jobs_loop_body(self):
         }
 
         async def user_cancelled_ready_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
-            async for batch in self.db.select_and_fetchall(
+            async for job_group in self.db.select_and_fetchall(
                 '''
-SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled
-FROM batches
+SELECT job_groups.batch_id, job_groups.job_group_id, job_groups_cancelled.id IS NOT NULL AS cancelled
+FROM job_groups
 LEFT JOIN job_groups_cancelled
-  ON batches.id = job_groups_cancelled.id
+  ON job_groups.batch_id = job_groups_cancelled.id AND
+     job_groups.job_group_id = job_groups_cancelled.job_group_id
 WHERE user = %s AND `state` = 'running';
 ''',
                 (user,),
             ):
-                if batch['cancelled']:
-                    async for record in self.db.select_and_fetchall(
+                if job_group['cancelled']:
+                    async for record in self.db.select_and_fetchall(  # FIXME: Do we need a new index again?
                         '''
 SELECT jobs.job_id
 FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
-WHERE batch_id = %s AND state = 'Ready' AND always_run = 0
+WHERE batch_id = %s AND job_group_id = %s AND state = 'Ready' AND always_run = 0
 LIMIT %s;
 ''',
-                        (batch['id'], remaining.value),
+                        (job_group['batch_id'], job_group['job_group_id'], remaining.value),
                     ):
-                        record['batch_id'] = batch['id']
+                        record['batch_id'] = job_group['batch_id']
                         yield record
                 else:
-                    async for record in self.db.select_and_fetchall(
+                    async for record in self.db.select_and_fetchall(  # FIXME: Do we need a new index again?
                         '''
 SELECT jobs.job_id
 FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
-WHERE batch_id = %s AND state = 'Ready' AND always_run = 0 AND cancelled = 1
+WHERE batch_id = %s AND job_group_id = %s AND state = 'Ready' AND always_run = 0 AND cancelled = 1
 LIMIT %s;
 ''',
-                        (batch['id'], remaining.value),
+                        (job_group['batch_id'], job_group['job_group_id'], remaining.value),
                     ):
-                        record['batch_id'] = batch['id']
+                        record['batch_id'] = job_group['batch_id']
                         yield record
 
         waitable_pool = WaitableSharedPool(self.async_worker_pool)
@@ -182,12 +183,13 @@ async def cancel_cancelled_creating_jobs_loop_body(self):
         }
 
         async def user_cancelled_creating_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
-            async for batch in self.db.select_and_fetchall(
+            async for job_group in self.db.select_and_fetchall(
                 '''
-SELECT batches.id
-FROM batches
+SELECT job_groups.batch_id, job_groups.job_group_id
+FROM job_groups
 INNER JOIN job_groups_cancelled
-  ON batches.id = job_groups_cancelled.id
+  ON job_groups.batch_id = job_groups_cancelled.id AND
+     job_groups.job_group_id = job_groups_cancelled.job_group_id
 WHERE user = %s AND `state` = 'running';
 ''',
                 (user,),
@@ -198,12 +200,12 @@ async def user_cancelled_creating_jobs(user, remaining) -> AsyncIterator[Dict[st
 FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
 STRAIGHT_JOIN attempts
   ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id
-WHERE jobs.batch_id = %s AND state = 'Creating' AND always_run = 0 AND cancelled = 0
+WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND state = 'Creating' AND always_run = 0 AND cancelled = 0
 LIMIT %s;
 ''',
-                        (batch['id'], remaining.value),
+                        (job_group['batch_id'], job_group['job_group_id'], remaining.value),
                     ):
-                        record['batch_id'] = batch['id']
+                        record['batch_id'] = job_group['batch_id']
                         yield record
 
         waitable_pool = WaitableSharedPool(self.async_worker_pool)
@@ -279,12 +281,13 @@ async def cancel_cancelled_running_jobs_loop_body(self):
         }
 
         async def user_cancelled_running_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
-            async for batch in self.db.select_and_fetchall(
+            async for job_group in self.db.select_and_fetchall(
                 '''
-SELECT batches.id
-FROM batches
+SELECT job_groups.batch_id, job_groups.job_group_id
+FROM job_groups
 INNER JOIN job_groups_cancelled
-  ON batches.id = job_groups_cancelled.id
+  ON job_groups.batch_id = job_groups_cancelled.id AND
+     job_groups.job_group_id = job_groups_cancelled.job_group_id
 WHERE user = %s AND `state` = 'running';
 ''',
                 (user,),
@@ -295,12 +298,12 @@ async def user_cancelled_running_jobs(user, remaining) -> AsyncIterator[Dict[str
 FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
 STRAIGHT_JOIN attempts
   ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id
-WHERE jobs.batch_id = %s AND state = 'Running' AND always_run = 0 AND cancelled = 0
+WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND state = 'Running' AND always_run = 0 AND cancelled = 0
 LIMIT %s;
 ''',
-                        (batch['id'], remaining.value),
+                        (job_group['batch_id'], job_group['job_group_id'], remaining.value),
                     ):
-                        record['batch_id'] = batch['id']
+                        record['batch_id'] = job_group['batch_id']
                         yield record
 
         waitable_pool = WaitableSharedPool(self.async_worker_pool)
diff --git a/batch/batch/driver/instance_collection/job_private.py b/batch/batch/driver/instance_collection/job_private.py
index d4800402cbc1..95f798be5c1b 100644
--- a/batch/batch/driver/instance_collection/job_private.py
+++ b/batch/batch/driver/instance_collection/job_private.py
@@ -179,12 +179,13 @@ async def schedule_jobs_loop_body(self):
         async for record in self.db.select_and_fetchall(
             '''
 SELECT jobs.*, batches.format_version, batches.userdata, batches.user, attempts.instance_name, time_ready
-FROM batches
-INNER JOIN jobs ON batches.id = jobs.batch_id
+FROM job_groups
+LEFT JOIN batches ON batches.id = job_groups.batch_id
+INNER JOIN jobs ON job_groups.batch_id = jobs.batch_id AND job_groups.job_group_id = jobs.job_group_id
 LEFT JOIN jobs_telemetry ON jobs.batch_id = jobs_telemetry.batch_id AND jobs.job_id = jobs_telemetry.job_id
 LEFT JOIN attempts ON jobs.batch_id = attempts.batch_id AND jobs.job_id = attempts.job_id
 LEFT JOIN instances ON attempts.instance_name = instances.name
-WHERE batches.state = 'running'
+WHERE job_groups.state = 'running'
   AND jobs.state = 'Creating'
   AND (jobs.always_run OR NOT jobs.cancelled)
   AND jobs.inst_coll = %s
@@ -349,54 +350,55 @@ async def create_instances_loop_body(self):
         }
 
         async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
-            async for batch in self.db.select_and_fetchall(
+            async for job_group in self.db.select_and_fetchall(
                 '''
-SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled, userdata, user, format_version
-FROM batches
+SELECT job_groups.batch_id, job_groups.job_group_id, job_groups_cancelled.id IS NOT NULL AS cancelled, userdata, job_groups.user, format_version
+FROM job_groups
+LEFT JOIN batches ON batches.id = job_groups.batch_id
 LEFT JOIN job_groups_cancelled
-  ON batches.id = job_groups_cancelled.id
-WHERE user = %s AND `state` = 'running';
+  ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
+WHERE job_groups.user = %s AND job_groups.`state` = 'running';
 ''',
                 (user,),
             ):
                 async for record in self.db.select_and_fetchall(
                     '''
 SELECT jobs.batch_id, jobs.job_id, jobs.spec, jobs.cores_mcpu, regions_bits_rep,
   COALESCE(SUM(instances.state IS NOT NULL AND
-    (instances.state = 'pending' OR instances.state = 'active')), 0) as live_attempts
+    (instances.state = 'pending' OR instances.state = 'active')), 0) as live_attempts, jobs.job_group_id
 FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_inst_coll_cancelled)
 LEFT JOIN attempts ON jobs.batch_id = attempts.batch_id AND jobs.job_id = attempts.job_id
 LEFT JOIN instances ON attempts.instance_name = instances.name
-WHERE jobs.batch_id = %s AND jobs.state = 'Ready' AND always_run = 1 AND jobs.inst_coll = %s
+WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND jobs.state = 'Ready' AND always_run = 1 AND jobs.inst_coll = %s
 GROUP BY jobs.job_id, jobs.spec, jobs.cores_mcpu
 HAVING live_attempts = 0
 LIMIT %s;
 ''',
-                    (batch['id'], self.name, remaining.value),
+                    (job_group['batch_id'], job_group['job_group_id'], self.name, remaining.value),
                 ):
-                    record['batch_id'] = batch['id']
-                    record['userdata'] = batch['userdata']
-                    record['user'] = batch['user']
-                    record['format_version'] = batch['format_version']
+                    record['batch_id'] = job_group['batch_id']
+                    record['userdata'] = job_group['userdata']
+                    record['user'] = job_group['user']
+                    record['format_version'] = job_group['format_version']
                     yield record
 
-                if not batch['cancelled']:
+                if not job_group['cancelled']:
                     async for record in self.db.select_and_fetchall(
                         '''
 SELECT jobs.batch_id, jobs.job_id, jobs.spec, jobs.cores_mcpu, regions_bits_rep,
   COALESCE(SUM(instances.state IS NOT NULL AND
-    (instances.state = 'pending' OR instances.state = 'active')), 0) as live_attempts
+    (instances.state = 'pending' OR instances.state = 'active')), 0) as live_attempts, job_group_id
 FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
 LEFT JOIN attempts ON jobs.batch_id = attempts.batch_id AND jobs.job_id = attempts.job_id
 LEFT JOIN instances ON attempts.instance_name = instances.name
-WHERE jobs.batch_id = %s AND jobs.state = 'Ready' AND always_run = 0 AND jobs.inst_coll = %s AND cancelled = 0
+WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND jobs.state = 'Ready' AND always_run = 0 AND jobs.inst_coll = %s AND cancelled = 0
 GROUP BY jobs.job_id, jobs.spec, jobs.cores_mcpu
 HAVING live_attempts = 0
 LIMIT %s
 ''',
-                        (batch['id'], self.name, remaining.value),
+                        (job_group['batch_id'], job_group['job_group_id'], self.name, remaining.value),
                     ):
-                        record['batch_id'] = batch['id']
-                        record['userdata'] = batch['userdata']
-                        record['user'] = batch['user']
-                        record['format_version'] = batch['format_version']
+                        record['batch_id'] = job_group['batch_id']
+                        record['userdata'] = job_group['userdata']
+                        record['user'] = job_group['user']
+                        record['format_version'] = job_group['format_version']
                         yield record
 
         waitable_pool = WaitableSharedPool(self.async_worker_pool)
@@ -420,6 +422,7 @@ async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
                 id = (batch_id, job_id)
                 attempt_id = secret_alnum_string(6)
                 record['attempt_id'] = attempt_id
+                job_group_id = record['job_group_id']
 
                 if n_user_instances_created >= n_allocated_instances:
                     if random.random() > self.exceeded_shares_counter.rate():
@@ -435,7 +438,7 @@ async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]:
                 log.info(f'creating job private instance for job {id}')
 
                 async def create_instance_with_error_handling(
-                    batch_id: int, job_id: int, attempt_id: str, record: dict, id: Tuple[int, int]
+                    batch_id: int, job_id: int, attempt_id: str, job_group_id: int, record: dict, id: Tuple[int, int]
                 ):
                     try:
                         batch_format_version = BatchFormatVersion(record['format_version'])
@@ -460,6 +463,7 @@ async def create_instance_with_error_handling(
                                 batch_id,
                                 job_id,
                                 attempt_id,
+                                job_group_id,
                                 record['user'],
                                 record['format_version'],
                                 traceback.format_exc(),
@@ -467,7 +471,9 @@ async def create_instance_with_error_handling(
                     except Exception:
                         log.exception(f'while creating job private instance for job {id}', exc_info=True)
 
-                await waitable_pool.call(create_instance_with_error_handling, batch_id, job_id, attempt_id, record, id)
+                await waitable_pool.call(
+                    create_instance_with_error_handling, batch_id, job_id, attempt_id, job_group_id, record, id
+                )
 
                 remaining.value -= 1
                 if remaining.value <= 0:
diff --git a/batch/batch/driver/instance_collection/pool.py b/batch/batch/driver/instance_collection/pool.py
index c8a6dbd8cd72..0554d7c2fc62 100644
--- a/batch/batch/driver/instance_collection/pool.py
+++ b/batch/batch/driver/instance_collection/pool.py
@@ -325,28 +325,28 @@ async def regions_to_ready_cores_mcpu_from_estimated_job_queue(self) -> List[Tup
 SELECT scheduling_iteration, user_idx, n_regions, regions_bits_rep,
   CAST(COALESCE(SUM(cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu
 FROM (
   SELECT {user_idx} AS user_idx, batch_id, job_id, cores_mcpu, always_run, n_regions, regions_bits_rep,
-    ROW_NUMBER() OVER (ORDER BY batch_id, always_run DESC, -n_regions DESC, regions_bits_rep, job_id ASC) DIV {share} AS scheduling_iteration
+    ROW_NUMBER() OVER (ORDER BY batch_id, job_group_id, always_run DESC, -n_regions DESC, regions_bits_rep, job_id ASC) DIV {share} AS scheduling_iteration
   FROM (
     (
-      SELECT jobs.batch_id, jobs.job_id, cores_mcpu, always_run, n_regions, regions_bits_rep
+      SELECT jobs.batch_id, jobs.job_id, jobs.job_group_id, cores_mcpu, always_run, n_regions, regions_bits_rep
       FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
       LEFT JOIN batches ON jobs.batch_id = batches.id
       WHERE user = %s AND batches.`state` = 'running' AND jobs.state = 'Ready' AND always_run AND inst_coll = %s
-      ORDER BY jobs.batch_id ASC, jobs.job_id ASC
+      ORDER BY jobs.batch_id ASC, jobs.job_group_id ASC, jobs.job_id ASC
      LIMIT {share * self.job_queue_scheduling_window_secs}
     )
     UNION
     (
-      SELECT jobs.batch_id, jobs.job_id, cores_mcpu, always_run, n_regions, regions_bits_rep
+      SELECT jobs.batch_id, jobs.job_id, jobs.job_group_id, cores_mcpu, always_run, n_regions, regions_bits_rep
       FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
       LEFT JOIN batches ON jobs.batch_id = batches.id
-      LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id
+      LEFT JOIN job_groups_cancelled ON jobs.batch_id = job_groups_cancelled.id AND jobs.job_group_id = job_groups_cancelled.job_group_id
      WHERE user = %s AND batches.`state` = 'running' AND jobs.state = 'Ready' AND NOT always_run AND job_groups_cancelled.id IS NULL AND inst_coll = %s
-      ORDER BY jobs.batch_id ASC, jobs.job_id ASC
+      ORDER BY jobs.batch_id ASC, jobs.job_group_id ASC, jobs.job_id ASC
       LIMIT {share * self.job_queue_scheduling_window_secs}
     )
   ) AS t1
-  ORDER BY batch_id, always_run DESC, -n_regions DESC, regions_bits_rep, job_id ASC
+  ORDER BY batch_id, job_group_id, always_run DESC, -n_regions DESC, regions_bits_rep, job_id ASC
   LIMIT {share * self.job_queue_scheduling_window_secs}
 ) AS t2
 GROUP BY scheduling_iteration, user_idx, regions_bits_rep, n_regions
@@ -605,51 +605,55 @@ async def schedule_loop_body(self):
         }
 
         async def user_runnable_jobs(user):
-            async for batch in self.db.select_and_fetchall(
+            async for job_group in self.db.select_and_fetchall(
                 '''
-SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled, userdata, user, format_version
-FROM batches
+SELECT job_groups.batch_id, job_groups.job_group_id, job_groups_cancelled.id IS NOT NULL AS cancelled, userdata, job_groups.user, format_version
+FROM job_groups
+LEFT JOIN batches ON job_groups.batch_id = batches.id
 LEFT JOIN job_groups_cancelled
-  ON batches.id = job_groups_cancelled.id
-WHERE user = %s AND `state` = 'running';
+  ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
+WHERE job_groups.user = %s AND job_groups.`state` = 'running'
+ORDER BY job_groups.batch_id, job_groups.job_group_id;
 ''',
                 (user,),
                 "user_runnable_jobs__select_running_batches",
             ):
                 async for record in self.db.select_and_fetchall(
                     '''
-SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready
+SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready, job_group_id
 FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_inst_coll_cancelled)
 LEFT JOIN jobs_telemetry ON jobs.batch_id = jobs_telemetry.batch_id AND jobs.job_id = jobs_telemetry.job_id
-WHERE jobs.batch_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 1
+WHERE jobs.batch_id = %s AND job_group_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 1
 ORDER BY jobs.batch_id, inst_coll, state, always_run, -n_regions DESC, regions_bits_rep, jobs.job_id
 LIMIT 300;
 ''',
-                    (batch['id'], self.pool.name),
+                    (job_group['batch_id'], job_group['job_group_id'], self.pool.name),
                     "user_runnable_jobs__select_ready_always_run_jobs",
                 ):
-                    record['batch_id'] = batch['id']
-                    record['userdata'] = batch['userdata']
-                    record['user'] = batch['user']
-                    record['format_version'] = batch['format_version']
+                    record['batch_id'] = job_group['batch_id']
+                    record['job_group_id'] = job_group['job_group_id']
+                    record['userdata'] = job_group['userdata']
+                    record['user'] = job_group['user']
+                    record['format_version'] = job_group['format_version']
                     yield record
 
-                if not batch['cancelled']:
-                    async for record in self.db.select_and_fetchall(
+                if not job_group['cancelled']:
+                    async for record in self.db.select_and_fetchall(  # FIXME: Do we need a different index?
                        '''
-SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready
+SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready, job_group_id
 FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled)
 LEFT JOIN jobs_telemetry ON jobs.batch_id = jobs_telemetry.batch_id AND jobs.job_id = jobs_telemetry.job_id
-WHERE jobs.batch_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 0 AND cancelled = 0
+WHERE jobs.batch_id = %s AND job_group_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 0 AND cancelled = 0
 ORDER BY jobs.batch_id, inst_coll, state, always_run, -n_regions DESC, regions_bits_rep, jobs.job_id
 LIMIT 300;
 ''',
-                        (batch['id'], self.pool.name),
+                        (job_group['batch_id'], job_group['job_group_id'], self.pool.name),
                         "user_runnable_jobs__select_ready_jobs_batch_not_cancelled",
                     ):
-                        record['batch_id'] = batch['id']
-                        record['userdata'] = batch['userdata']
-                        record['user'] = batch['user']
-                        record['format_version'] = batch['format_version']
+                        record['batch_id'] = job_group['batch_id']
+                        record['job_group_id'] = job_group['job_group_id']
+                        record['userdata'] = job_group['userdata']
+                        record['user'] = job_group['user']
+                        record['format_version'] = job_group['format_version']
                         yield record
 
         waitable_pool = WaitableSharedPool(self.async_worker_pool)
@@ -681,6 +685,7 @@ async def user_runnable_jobs(user):
                         record['batch_id'],
                         record['job_id'],
                         attempt_id,
+                        record['job_group_id'],
                         record['user'],
                         BatchFormatVersion(record['format_version']),
                         f'no regions given in {regions} are supported. choose from a region in {supported_regions}',
diff --git a/batch/batch/driver/job.py b/batch/batch/driver/job.py
index a4b54705e3eb..7026ddc0bca6 100644
--- a/batch/batch/driver/job.py
+++ b/batch/batch/driver/job.py
@@ -16,6 +16,7 @@
 from ..batch import batch_record_to_dict
 from ..batch_configuration import KUBERNETES_SERVER_URL
 from ..batch_format_version import BatchFormatVersion
+from ..constants import ROOT_JOB_GROUP_ID
 from ..file_store import FileStore
 from ..globals import STATUS_FORMAT_VERSION, complete_states, tasks
 from ..instance_config import QuantifiedResource
@@ -39,26 +40,27 @@ async def notify_batch_job_complete(db: Database, client_session: httpx.ClientSe
        job_groups_n_jobs_in_complete_states.n_succeeded,
        job_groups_n_jobs_in_complete_states.n_failed,
        job_groups_n_jobs_in_complete_states.n_cancelled
-FROM batches
+FROM job_groups
+LEFT JOIN batches ON job_groups.batch_id = batches.id
 LEFT JOIN job_groups_n_jobs_in_complete_states
-  ON batches.id = job_groups_n_jobs_in_complete_states.id
+  ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id
 LEFT JOIN LATERAL (
   SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown
   FROM (
-    SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
+    SELECT batch_id, job_group_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
     FROM aggregated_job_group_resources_v3
-    WHERE batches.id = aggregated_job_group_resources_v3.batch_id
-    GROUP BY batch_id, resource_id
+    WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id
+    GROUP BY batch_id, job_group_id, resource_id
   ) AS usage_t
   LEFT JOIN resources ON usage_t.resource_id = resources.resource_id
-  GROUP BY batch_id
+  GROUP BY batch_id, job_group_id
 ) AS cost_t ON TRUE
 LEFT JOIN job_groups_cancelled
-  ON batches.id = job_groups_cancelled.id
-WHERE batches.id = %s AND NOT deleted AND callback IS NOT NULL AND
-  batches.`state` = 'complete';
+  ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
+WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted AND job_groups.callback IS NOT NULL AND
+  job_groups.`state` = 'complete';
 ''',
-        (batch_id,),
+        (batch_id, ROOT_JOB_GROUP_ID),
         'notify_batch_job_complete',
     )
@@ -333,7 +335,7 @@ async def make_request():
     log.info(f'unschedule job {id}, attempt {attempt_id}: called delete job')
 
 
-async def job_config(app, record, attempt_id):
+async def job_config(app, record, attempt_id, job_group_id):
     k8s_cache: K8sCache = app['k8s_cache']
     db: Database = app['db']
@@ -352,6 +354,7 @@ async def job_config(app, record, attempt_id):
         job_spec = db_spec
 
     job_spec['attempt_id'] = attempt_id
+    job_spec['job_group_id'] = job_group_id
 
     userdata = json.loads(record['userdata'])
@@ -436,6 +439,7 @@ async def job_config(app, record, attempt_id):
     return {
         'batch_id': batch_id,
         'job_id': job_id,
+        'job_group_id': job_group_id,
         'format_version': format_version.format_version,
         'token': spec_token,
         'start_job_id': start_job_id,
@@ -446,7 +450,7 @@ async def job_config(app, record, attempt_id):
     }
 
 
-async def mark_job_errored(app, batch_id, job_id, attempt_id, user, format_version, error_msg):
+async def mark_job_errored(app, batch_id, job_id, attempt_id, job_group_id, user, format_version, error_msg):
     file_store: FileStore = app['file_store']
 
     status = {
@@ -454,6 +458,7 @@ async def mark_job_errored(app, batch_id, job_id, attempt_id, user, format_versi
         'worker': None,
         'batch_id': batch_id,
         'job_id': job_id,
+        'job_group_id': job_group_id,
         'attempt_id': attempt_id,
         'user': user,
         'state': 'error',
@@ -478,17 +483,18 @@ async def schedule_job(app, record, instance):
     batch_id = record['batch_id']
     job_id = record['job_id']
     attempt_id = record['attempt_id']
+    job_group_id = record['job_group_id']
     format_version = BatchFormatVersion(record['format_version'])
 
     id = (batch_id, job_id)
 
     try:
-        body = await job_config(app, record, attempt_id)
+        body = await job_config(app, record, attempt_id, job_group_id)
     except Exception:
         log.exception(f'while making job config for job {id} with attempt id {attempt_id}')
 
         await mark_job_errored(
-            app, batch_id, job_id, attempt_id, record['user'], format_version, traceback.format_exc()
+            app, batch_id, job_id, attempt_id, job_group_id, record['user'], format_version, traceback.format_exc()
         )
         raise
diff --git a/batch/batch/driver/main.py b/batch/batch/driver/main.py
index 880078e36769..253143341f69 100644
--- a/batch/batch/driver/main.py
+++ b/batch/batch/driver/main.py
@@ -62,6 +62,7 @@
 )
 from ..cloud.driver import get_cloud_driver
 from ..cloud.resource_utils import local_ssd_size, possible_cores_from_worker_type, unreserved_worker_data_disk_size_gib
+from ..constants import ROOT_JOB_GROUP_ID
 from ..exceptions import BatchUserError
 from ..file_store import FileStore
 from ..globals import HTTP_CLIENT_MAX_SIZE
@@ -1018,13 +1019,13 @@ async def check(tx):
            CAST(COALESCE(SUM(state = 'Creating' AND cancelled), 0) AS SIGNED) AS actual_n_cancelled_creating_jobs
     FROM (
-      SELECT batches.user, jobs.state, jobs.cores_mcpu, jobs.inst_coll,
+      SELECT job_groups.user, jobs.state, jobs.cores_mcpu, jobs.inst_coll,
         (jobs.always_run OR NOT (jobs.cancelled OR job_groups_cancelled.id IS NOT NULL)) AS runnable,
         (NOT jobs.always_run AND (jobs.cancelled OR job_groups_cancelled.id IS NOT NULL)) AS cancelled
-      FROM batches
-      INNER JOIN jobs ON batches.id = jobs.batch_id
-      LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id
-      WHERE batches.`state` = 'running'
+      FROM jobs
+      INNER JOIN job_groups ON job_groups.batch_id = jobs.batch_id AND job_groups.job_group_id = jobs.job_group_id
+      LEFT JOIN job_groups_cancelled ON jobs.batch_id = job_groups_cancelled.id AND jobs.job_group_id = job_groups_cancelled.job_group_id
+      WHERE job_groups.`state` = 'running'
     ) as v
     GROUP BY user, inst_coll
   ) as t
@@ -1109,39 +1110,41 @@ def fold(d, key_f):
    async def check(tx):
        attempt_resources = tx.execute_and_fetchall(
            '''
-SELECT attempt_resources.batch_id, attempt_resources.job_id, attempt_resources.attempt_id,
+SELECT attempt_resources.batch_id, jobs.job_group_id, attempt_resources.job_id, attempt_resources.attempt_id,
  JSON_OBJECTAGG(resources.resource, quantity * GREATEST(COALESCE(rollup_time - start_time, 0), 0)) as resources
 FROM attempt_resources
 INNER JOIN attempts
 ON attempts.batch_id = attempt_resources.batch_id AND
 attempts.job_id = attempt_resources.job_id AND
 attempts.attempt_id = attempt_resources.attempt_id
+LEFT JOIN jobs ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id
 LEFT JOIN resources ON attempt_resources.resource_id = resources.resource_id
 WHERE GREATEST(COALESCE(rollup_time - start_time, 0), 0) != 0
-GROUP BY batch_id, job_id, attempt_id
+GROUP BY attempt_resources.batch_id, jobs.job_group_id, attempt_resources.job_id, attempt_resources.attempt_id
 LOCK IN SHARE MODE;
 '''
        )

        agg_job_resources = tx.execute_and_fetchall(
            '''
-SELECT batch_id, job_id, JSON_OBJECTAGG(resource, `usage`) as resources
+SELECT aggregated_job_resources_v3.batch_id, jobs.job_group_id, aggregated_job_resources_v3.job_id, JSON_OBJECTAGG(resource, `usage`) as resources
 FROM aggregated_job_resources_v3
+LEFT JOIN jobs ON aggregated_job_resources_v3.batch_id = jobs.batch_id AND aggregated_job_resources_v3.job_id = jobs.job_id
 LEFT JOIN resources ON aggregated_job_resources_v3.resource_id = resources.resource_id
-GROUP BY batch_id, job_id
+GROUP BY aggregated_job_resources_v3.batch_id, jobs.job_group_id, aggregated_job_resources_v3.job_id
 LOCK IN SHARE MODE;
 '''
        )

-        agg_batch_resources = tx.execute_and_fetchall(
+        agg_job_group_resources = tx.execute_and_fetchall(
            '''
-SELECT batch_id, billing_project, JSON_OBJECTAGG(resource, `usage`) as resources
+SELECT batch_id, job_group_id, billing_project, JSON_OBJECTAGG(resource, `usage`) as resources
 FROM (
-  SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
+  SELECT batch_id, job_group_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
   FROM aggregated_job_group_resources_v3
-  GROUP BY batch_id, resource_id) AS t
+  GROUP BY batch_id, job_group_id, resource_id) AS t
 LEFT JOIN resources ON t.resource_id = resources.resource_id
 JOIN batches ON batches.id = t.batch_id
-GROUP BY t.batch_id, billing_project
+GROUP BY t.batch_id, t.job_group_id, billing_project
 LOCK IN SHARE MODE;
 '''
        )
@@ -1161,18 +1164,20 @@ async def check(tx):
        )

        attempt_resources = {
-            (record['batch_id'], record['job_id'], record['attempt_id']): json_to_value(record['resources'])
+            (record['batch_id'], record['job_group_id'], record['job_id'], record['attempt_id']): json_to_value(
+                record['resources']
+            )
            async for record in attempt_resources
        }

        agg_job_resources = {
-            (record['batch_id'], record['job_id']): json_to_value(record['resources'])
+            (record['batch_id'], record['job_group_id'], record['job_id']): json_to_value(record['resources'])
            async for record in agg_job_resources
        }

-        agg_batch_resources = {
-            (record['batch_id'], record['billing_project']): json_to_value(record['resources'])
-            async for record in agg_batch_resources
+        agg_job_group_resources = {
+            (record['batch_id'], record['job_group_id'], record['billing_project']): json_to_value(record['resources'])
+            async for record in agg_job_group_resources
        }

        agg_billing_project_resources = {
@@ -1180,31 +1185,34 @@ async def check(tx):
            async for record in agg_billing_project_resources
        }

-        attempt_by_batch_resources = fold(attempt_resources, lambda k: k[0])
-        attempt_by_job_resources = fold(attempt_resources, lambda k: (k[0], k[1]))
-        job_by_batch_resources = fold(agg_job_resources, lambda k: k[0])
-        batch_by_billing_project_resources = fold(agg_batch_resources, lambda k: k[1])
+        attempt_by_job_group_resources = fold(attempt_resources, lambda k: (k[0], k[1]))
+        attempt_by_job_resources = fold(attempt_resources, lambda k: (k[0], k[1], k[2]))
+        job_by_job_group_resources = fold(agg_job_resources, lambda k: (k[0], k[1]))
+        job_group_by_billing_project_resources = fold(agg_job_group_resources, lambda k: k[2])

-        agg_batch_resources_2 = {batch_id: resources for (batch_id, _), resources in agg_batch_resources.items()}
+        agg_job_group_resources_2 = {
+            (batch_id, job_group_id): resources
+            for (batch_id, job_group_id, _), resources in agg_job_group_resources.items()
+        }

-        assert attempt_by_batch_resources == agg_batch_resources_2, (
-            dictdiffer.diff(attempt_by_batch_resources, agg_batch_resources_2),
-            attempt_by_batch_resources,
-            agg_batch_resources_2,
+        assert attempt_by_job_group_resources == agg_job_group_resources_2, (
+            dictdiffer.diff(attempt_by_job_group_resources, agg_job_group_resources_2),
+            attempt_by_job_group_resources,
+            agg_job_group_resources_2,
        )
        assert attempt_by_job_resources == agg_job_resources, (
            dictdiffer.diff(attempt_by_job_resources, agg_job_resources),
            attempt_by_job_resources,
            agg_job_resources,
        )
-        assert job_by_batch_resources == agg_batch_resources_2, (
-            dictdiffer.diff(job_by_batch_resources, agg_batch_resources_2),
-            job_by_batch_resources,
-            agg_batch_resources_2,
+        assert job_by_job_group_resources == agg_job_group_resources_2, (
+            dictdiffer.diff(job_by_job_group_resources, agg_job_group_resources_2),
+            job_by_job_group_resources,
+            agg_job_group_resources_2,
        )
-        assert batch_by_billing_project_resources == agg_billing_project_resources, (
-            dictdiffer.diff(batch_by_billing_project_resources, agg_billing_project_resources),
-            batch_by_billing_project_resources,
+        assert job_group_by_billing_project_resources == agg_billing_project_resources, (
+            dictdiffer.diff(job_group_by_billing_project_resources, agg_billing_project_resources),
+            job_group_by_billing_project_resources,
            agg_billing_project_resources,
        )
@@ -1245,15 +1253,16 @@ async def cancel_fast_failing_batches(app):
     records = db.select_and_fetchall(
         '''
-SELECT batches.id, job_groups_n_jobs_in_complete_states.n_failed
-FROM batches
+SELECT job_groups.batch_id, job_groups_n_jobs_in_complete_states.n_failed
+FROM job_groups
 LEFT JOIN job_groups_n_jobs_in_complete_states
-  ON batches.id = job_groups_n_jobs_in_complete_states.id
-WHERE state = 'running' AND cancel_after_n_failures IS NOT NULL AND n_failed >= cancel_after_n_failures
-'''
+  ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id
+WHERE state = 'running' AND cancel_after_n_failures IS NOT NULL AND n_failed >= cancel_after_n_failures AND job_groups.job_group_id = %s
+''',
+        (ROOT_JOB_GROUP_ID,),
     )
     async for batch in records:
-        await _cancel_batch(app, batch['id'])
+        await _cancel_batch(app, batch['batch_id'])
 
 
 USER_CORES = pc.Gauge('batch_user_cores', 'Batch user cores (i.e. total in-use cores)', ['state', 'user', 'inst_coll'])
diff --git a/batch/batch/front_end/front_end.py b/batch/batch/front_end/front_end.py
index 02ccf71626f1..95cf3c35aad9 100644
--- a/batch/batch/front_end/front_end.py
+++ b/batch/batch/front_end/front_end.py
@@ -1456,11 +1456,11 @@ async def update(tx: Transaction):
             '''
 SELECT job_groups_cancelled.id IS NOT NULL AS cancelled
 FROM batches
-LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id
-WHERE batches.id = %s AND user = %s AND NOT deleted
+LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id AND job_groups_cancelled.job_group_id = %s
+WHERE batches.id = %s AND batches.user = %s AND NOT deleted
 FOR UPDATE;
 ''',
-            (batch_id, user),
+            (ROOT_JOB_GROUP_ID, batch_id, user),
         )
         if not record:
             raise web.HTTPNotFound()
@@ -1512,25 +1512,26 @@ async def _get_batch(app, batch_id):
        job_groups_n_jobs_in_complete_states.n_failed,
        job_groups_n_jobs_in_complete_states.n_cancelled,
        cost_t.*
-FROM batches
+FROM job_groups
+LEFT JOIN batches ON batches.id = job_groups.batch_id
 LEFT JOIN job_groups_n_jobs_in_complete_states
-       ON batches.id = job_groups_n_jobs_in_complete_states.id
+       ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id
 LEFT JOIN job_groups_cancelled
-       ON batches.id = job_groups_cancelled.id
+       ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
 LEFT JOIN LATERAL (
   SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown
   FROM (
-    SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
+    SELECT batch_id, job_group_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
     FROM aggregated_job_group_resources_v3
-    WHERE batches.id = aggregated_job_group_resources_v3.batch_id
-    GROUP BY batch_id, resource_id
+    WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id
+    GROUP BY batch_id, job_group_id, resource_id
   ) AS usage_t
   LEFT JOIN resources ON usage_t.resource_id = resources.resource_id
-  GROUP BY batch_id
+  GROUP BY batch_id, job_group_id
 ) AS cost_t ON TRUE
-WHERE batches.id = %s AND NOT deleted;
+WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted;
 ''',
-        (batch_id,),
+        (batch_id, ROOT_JOB_GROUP_ID),
     )
     if not record:
         raise web.HTTPNotFound()
@@ -1593,11 +1594,12 @@ async def close_batch(request, userdata):
     record = await db.select_and_fetchone(
         '''
 SELECT job_groups_cancelled.id IS NOT NULL AS cancelled
-FROM batches
-LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id
-WHERE user = %s AND batches.id = %s AND NOT deleted;
+FROM job_groups
+LEFT JOIN batches ON batches.id = job_groups.batch_id
+LEFT JOIN job_groups_cancelled ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
+WHERE job_groups.user = %s AND job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted;
 ''',
-        (user, batch_id),
+        (user, batch_id, ROOT_JOB_GROUP_ID),
     )
     if not record:
         raise web.HTTPNotFound()
@@ -1630,12 +1631,13 @@ async def commit_update(request: web.Request, userdata):
     record = await db.select_and_fetchone(
         '''
 SELECT start_job_id, job_groups_cancelled.id IS NOT NULL AS cancelled
-FROM batches
-LEFT JOIN batch_updates ON batches.id = batch_updates.batch_id
-LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id
-WHERE user = %s AND batches.id = %s AND batch_updates.update_id = %s AND NOT deleted;
+FROM job_groups
+LEFT JOIN batches ON job_groups.batch_id = batches.id
+LEFT JOIN batch_updates ON job_groups.batch_id = batch_updates.batch_id
+LEFT JOIN job_groups_cancelled ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
+WHERE job_groups.user = %s AND job_groups.batch_id = %s AND job_groups.job_group_id = %s AND batch_updates.update_id = %s AND NOT deleted;
 ''',
-        (user, batch_id, update_id),
+        (user, batch_id, ROOT_JOB_GROUP_ID, update_id),
     )
     if not record:
         raise web.HTTPNotFound()
diff --git a/batch/batch/front_end/query/query.py b/batch/batch/front_end/query/query.py
index c28d804e9f05..5d57929c7440 100644
--- a/batch/batch/front_end/query/query.py
+++ b/batch/batch/front_end/query/query.py
@@ -361,16 +361,16 @@ def __init__(self, state: BatchState, operator: ExactMatchOperator):
     def query(self) -> Tuple[str, List[Any]]:
         args: List[Any]
         if self.state == BatchState.OPEN:
-            condition = "(`state` = 'open')"
+            condition = "(batches.`state` = 'open')"
             args = []
         elif self.state == BatchState.CLOSED:
-            condition = "(`state` != 'open')"
+            condition = "(batches.`state` != 'open')"
             args = []
         elif self.state == BatchState.COMPLETE:
-            condition = "(`state` = 'complete')"
+            condition = "(batches.`state` = 'complete')"
             args = []
         elif self.state == BatchState.RUNNING:
-            condition = "(`state` = 'running')"
+            condition = "(batches.`state` = 'running')"
             args = []
         elif self.state == BatchState.CANCELLED:
             condition = '(job_groups_cancelled.id IS NOT NULL)'
             args = []
@@ -381,7 +381,7 @@ def query(self) -> Tuple[str, List[Any]]:
         else:
             assert self.state == BatchState.SUCCESS
             # need complete because there might be no jobs
-            condition = "(`state` = 'complete' AND n_succeeded = n_jobs)"
+            condition = "(batches.`state` = 'complete' AND n_succeeded = batches.n_jobs)"
             args = []
 
         if isinstance(self.operator, NotEqualExactMatchOperator):
@@ -442,58 +442,58 @@ def query(self) -> Tuple[str, List[str]]:
         return (f'(batches.billing_project {op} %s)', [self.billing_project])
 
 
-class BatchQuotedExactMatchQuery(Query):
+class JobGroupQuotedExactMatchQuery(Query):
     @staticmethod
-    def parse(term: str) -> 'BatchQuotedExactMatchQuery':
+    def parse(term: str) -> 'JobGroupQuotedExactMatchQuery':
         if len(term) < 3:
             raise QueryError(f'expected a string of minimum length 3. Found {term}')
         if term[-1] != '"':
             raise QueryError("expected the last character of the string to be '\"'")
-        return BatchQuotedExactMatchQuery(term[1:-1])
+        return JobGroupQuotedExactMatchQuery(term[1:-1])
 
     def __init__(self, term: str):
         self.term = term
 
     def query(self) -> Tuple[str, List[str]]:
         sql = '''
-((batches.id) IN
- (SELECT batch_id FROM job_group_attributes
+((job_groups.batch_id, job_groups.job_group_id) IN
+ (SELECT batch_id, job_group_id FROM job_group_attributes
   WHERE `key` = %s OR `value` = %s))
 '''
         return (sql, [self.term, self.term])
 
 
-class BatchUnquotedPartialMatchQuery(Query):
+class JobGroupUnquotedPartialMatchQuery(Query):
     @staticmethod
-    def parse(term: str) -> 'BatchUnquotedPartialMatchQuery':
+    def parse(term: str) -> 'JobGroupUnquotedPartialMatchQuery':
         if len(term) < 1:
             raise QueryError(f'expected a string of minimum length 1. Found {term}')
         if term[0] == '"':
             raise QueryError("expected the first character of the string to not be '\"'")
         if term[-1] == '"':
             raise QueryError("expected the last character of the string to not be '\"'")
-        return BatchUnquotedPartialMatchQuery(term)
+        return JobGroupUnquotedPartialMatchQuery(term)
 
     def __init__(self, term: str):
         self.term = term
 
     def query(self) -> Tuple[str, List[str]]:
         sql = '''
-((batches.id) IN
- (SELECT batch_id FROM job_group_attributes
+((job_groups.batch_id, job_groups.job_group_id) IN
+ (SELECT batch_id, job_group_id FROM job_group_attributes
   WHERE `key` LIKE %s OR `value` LIKE %s))
 '''
         escaped_term = f'%{self.term}%'
         return (sql, [escaped_term, escaped_term])
 
 
-class BatchKeywordQuery(Query):
+class JobGroupKeywordQuery(Query):
     @staticmethod
-    def parse(op: str, key: str, value: str) -> 'BatchKeywordQuery':
+    def parse(op: str, key: str, value: str) -> 'JobGroupKeywordQuery':
         operator = get_operator(op)
         if not isinstance(operator, MatchOperator):
             raise QueryError(f'unexpected operator "{op}" expected one of {MatchOperator.symbols}')
-        return BatchKeywordQuery(operator, key, value)
+        return JobGroupKeywordQuery(operator, key, value)
 
     def __init__(self, operator: MatchOperator, key: str, value: str):
         self.operator = operator
@@ -506,21 +506,21 @@ def query(self) -> Tuple[str, List[str]]:
         if isinstance(self.operator, PartialMatchOperator):
             value = f'%{value}%'
         sql = f'''
-((batches.id) IN
- (SELECT batch_id FROM job_group_attributes
+((job_groups.batch_id, job_groups.job_group_id) IN
+ (SELECT batch_id, job_group_id FROM job_group_attributes
   WHERE `key` = %s AND `value` {op} %s))
 '''
         return (sql, [self.key, value])
 
 
-class BatchStartTimeQuery(Query):
+class JobGroupStartTimeQuery(Query):
     @staticmethod
-    def parse(op: str, time: str) -> 'BatchStartTimeQuery':
+    def parse(op: str, time: str) -> 'JobGroupStartTimeQuery':
         operator = get_operator(op)
         if not isinstance(operator, ComparisonOperator):
             raise QueryError(f'unexpected operator "{op}" expected one of {ComparisonOperator.symbols}')
         time_msecs = parse_date(time)
-        return BatchStartTimeQuery(operator, time_msecs)
+        return JobGroupStartTimeQuery(operator, time_msecs)
 
     def __init__(self, operator: ComparisonOperator, time_msecs: int):
         self.operator = operator
@@ -528,18 +528,18 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int):
 
     def query(self) -> Tuple[str, List[int]]:
         op = self.operator.to_sql()
-        sql = f'(batches.time_created {op} %s)'
+        sql = f'(job_groups.time_created {op} %s)'
         return (sql, [self.time_msecs])
 
 
-class BatchEndTimeQuery(Query):
+class JobGroupEndTimeQuery(Query):
     @staticmethod
-    def parse(op: str, time: str) -> 'BatchEndTimeQuery':
+    def parse(op: str, time: str) -> 'JobGroupEndTimeQuery':
         operator = get_operator(op)
         if not isinstance(operator, ComparisonOperator):
             raise QueryError(f'unexpected operator "{op}" expected one of {ComparisonOperator.symbols}')
         time_msecs = parse_date(time)
-        return BatchEndTimeQuery(operator, time_msecs)
+        return JobGroupEndTimeQuery(operator, time_msecs)
 
     def __init__(self, operator: ComparisonOperator, time_msecs: int):
         self.operator = operator
@@ -547,18 +547,18 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int):
 
     def query(self) -> Tuple[str, List[int]]:
         op = self.operator.to_sql()
-        sql = f'(batches.time_completed {op} %s)'
+        sql = f'(job_groups.time_completed {op} %s)'
         return (sql, [self.time_msecs])
 
 
-class BatchDurationQuery(Query):
+class JobGroupDurationQuery(Query):
     @staticmethod
-    def parse(op: str, time: str) -> 'BatchDurationQuery':
+    def parse(op: str, time: str) -> 'JobGroupDurationQuery':
         operator = get_operator(op)
         if not isinstance(operator, ComparisonOperator):
             raise QueryError(f'unexpected operator "{op}" expected one of {ComparisonOperator.symbols}')
         time_msecs = int(parse_float(time) * 1000)
-        return BatchDurationQuery(operator, time_msecs)
+        return JobGroupDurationQuery(operator, time_msecs)
 
     def __init__(self, operator: ComparisonOperator, time_msecs: int):
         self.operator = operator
@@ -566,18 +566,18 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int):
 
     def query(self) -> Tuple[str, List[int]]:
         op = self.operator.to_sql()
-        sql = f'((batches.time_completed - batches.time_created) {op} %s)'
+        sql = f'((job_groups.time_completed - job_groups.time_created) {op} %s)'
         return (sql, [self.time_msecs])
 
 
-class BatchCostQuery(Query):
+class JobGroupCostQuery(Query):
     @staticmethod
-    def parse(op: str, cost_str: str) -> 'BatchCostQuery':
+    def parse(op: str, cost_str: str) -> 'JobGroupCostQuery':
         operator = get_operator(op)
         if not isinstance(operator, ComparisonOperator):
             raise QueryError(f'unexpected operator "{op}" expected one of {ComparisonOperator.symbols}')
         cost = parse_cost(cost_str)
-        return BatchCostQuery(operator, cost)
+        return JobGroupCostQuery(operator, cost)
 
     def __init__(self, operator: ComparisonOperator, cost: float):
         self.operator = operator
diff --git a/batch/batch/front_end/query/query_v1.py b/batch/batch/front_end/query/query_v1.py
index a52b1cf2c251..cacc6b7c76b3 100644
--- a/batch/batch/front_end/query/query_v1.py
+++ b/batch/batch/front_end/query/query_v1.py
@@ -1,5 +1,6 @@
 from typing import Any, List, Optional, Tuple
 
+from ...constants import ROOT_JOB_GROUP_ID
 from ...exceptions import QueryError
 from .query import job_state_search_term_to_states
 
@@ -8,11 +9,12 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int])
     where_conditions = [
         '(billing_project_users.`user` = %s AND billing_project_users.billing_project = batches.billing_project)',
         'NOT deleted',
+        'job_groups.job_group_id = %s',
     ]
-    where_args: List[Any] = [user]
+    where_args: List[Any] = [user, ROOT_JOB_GROUP_ID]
 
     if last_batch_id is not None:
-        where_conditions.append('(batches.id < %s)')
+        where_conditions.append('(job_groups.batch_id < %s)')
         where_args.append(last_batch_id)
 
     terms = q.split()
@@ -27,16 +29,16 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int])
         if '=' in t:
             k, v = t.split('=', 1)
             condition = '''
-((batches.id) IN
- (SELECT batch_id FROM job_group_attributes
+((job_groups.batch_id, job_groups.job_group_id) IN
+ (SELECT batch_id, job_group_id FROM job_group_attributes
   WHERE `key` = %s AND `value` = %s))
 '''
             args = [k, v]
         elif t.startswith('has:'):
             k = t[4:]
             condition = '''
-((batches.id) IN
- (SELECT batch_id FROM job_group_attributes
+((job_groups.batch_id, job_groups.job_group_id) IN
+ (SELECT batch_id, job_group_id FROM job_group_attributes
   WHERE `key` = %s))
 '''
             args = [k]
@@ -53,16 +55,16 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int])
 '''
             args = [k]
         elif t == 'open':
-            condition = "(`state` = 'open')"
+            condition = "(batches.`state` = 'open')"
             args = []
         elif t == 'closed':
-            condition = "(`state` != 'open')"
+            condition = "(batches.`state` != 'open')"
             args = []
         elif t == 'complete':
-            condition = "(`state` = 'complete')"
+            condition = "(batches.`state` = 'complete')"
             args = []
         elif t == 'running':
-            condition = "(`state` = 'running')"
+            condition = "(batches.`state` = 'running')"
             args = []
         elif t == 'cancelled':
             condition = '(job_groups_cancelled.id IS NOT NULL)'
@@ -72,7 +74,7 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int])
             args = []
         elif t == 'success':
             # need complete because there might be no jobs
-            condition = "(`state` = 'complete' AND n_succeeded = n_jobs)"
+            condition = "(batches.`state` = 'complete' AND n_succeeded = batches.n_jobs)"
             args = []
         else:
             raise QueryError(f'Invalid search term: {t}.')
@@ -85,21 +87,22 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int])
     sql = f'''
 WITH base_t AS (
-  SELECT batches.*,
+  SELECT batches.*, job_groups.batch_id, job_groups.job_group_id,
     job_groups_cancelled.id IS NOT NULL AS cancelled,
     job_groups_n_jobs_in_complete_states.n_completed,
     job_groups_n_jobs_in_complete_states.n_succeeded,
     job_groups_n_jobs_in_complete_states.n_failed,
     job_groups_n_jobs_in_complete_states.n_cancelled
-  FROM batches
+  FROM job_groups
+  LEFT JOIN batches ON batches.id = job_groups.batch_id
   LEFT JOIN billing_projects ON batches.billing_project = billing_projects.name
   LEFT JOIN job_groups_n_jobs_in_complete_states
-    ON batches.id = job_groups_n_jobs_in_complete_states.id
+    ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id
   LEFT JOIN job_groups_cancelled
-    ON batches.id = job_groups_cancelled.id
+    ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
   STRAIGHT_JOIN billing_project_users ON batches.billing_project = billing_project_users.billing_project
   WHERE {' AND '.join(where_conditions)}
-  ORDER BY id DESC
+  ORDER BY batch_id DESC
   LIMIT 51
 )
 SELECT base_t.*, cost_t.cost, cost_t.cost_breakdown
@@ -107,15 +110,15 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int])
 LEFT JOIN LATERAL (
   SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown
   FROM (
-    SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
+    SELECT batch_id, job_group_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
     FROM aggregated_job_group_resources_v3
-    WHERE base_t.id = aggregated_job_group_resources_v3.batch_id
-    GROUP BY batch_id, resource_id
+    WHERE base_t.id = aggregated_job_group_resources_v3.batch_id AND base_t.job_group_id = aggregated_job_group_resources_v3.job_group_id
+    GROUP BY batch_id, job_group_id, resource_id
   ) AS usage_t
   LEFT JOIN resources ON usage_t.resource_id = resources.resource_id
-  GROUP BY batch_id
+  GROUP BY batch_id, job_group_id
 ) AS cost_t ON TRUE
-ORDER BY id DESC;
+ORDER BY batch_id DESC;
 '''
 
     return (sql, where_args)
diff --git a/batch/batch/front_end/query/query_v2.py b/batch/batch/front_end/query/query_v2.py
index ad2df661ff80..924c263d0398 100644
--- a/batch/batch/front_end/query/query_v2.py
+++ b/batch/batch/front_end/query/query_v2.py
@@ -1,5 +1,6 @@
 from typing import Any, List, Optional, Tuple
 
+from ...constants import ROOT_JOB_GROUP_ID
 from ...exceptions import QueryError
 from .operators import (
     GreaterThanEqualOperator,
@@ -10,19 +11,19 @@
 )
 from .query import (
     BatchBillingProjectQuery,
-    BatchCostQuery,
-    BatchDurationQuery,
-    BatchEndTimeQuery,
     BatchIdQuery,
-    BatchKeywordQuery,
-    BatchQuotedExactMatchQuery,
-    BatchStartTimeQuery,
     BatchStateQuery,
-    BatchUnquotedPartialMatchQuery,
     BatchUserQuery,
     JobCostQuery,
     JobDurationQuery,
     JobEndTimeQuery,
+    JobGroupCostQuery,
+    JobGroupDurationQuery,
+    JobGroupEndTimeQuery,
+    JobGroupKeywordQuery,
+    JobGroupQuotedExactMatchQuery,
+    JobGroupStartTimeQuery,
+    JobGroupUnquotedPartialMatchQuery,
     JobIdQuery,
     JobInstanceCollectionQuery,
     JobInstanceQuery,
@@ -58,8 +59,8 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int])
     queries: List[Query] = []
 
     # logic to make time interval queries fast
-    min_start_gt_query: Optional[BatchStartTimeQuery] = None
-    max_end_lt_query: Optional[BatchEndTimeQuery] = None
+    min_start_gt_query: Optional[JobGroupStartTimeQuery] = None
+    max_end_lt_query: Optional[JobGroupEndTimeQuery] = None
 
     if q:
         terms = q.rstrip().lstrip().split('\n')
@@ -69,9 +70,9 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int])
             if len(statement) == 1:
                 word = statement[0]
                 if word[0] == '"':
-                    queries.append(BatchQuotedExactMatchQuery.parse(word))
+                    queries.append(JobGroupQuotedExactMatchQuery.parse(word))
                 else:
-                    queries.append(BatchUnquotedPartialMatchQuery.parse(word))
+                    queries.append(JobGroupUnquotedPartialMatchQuery.parse(word))
             elif len(statement) == 3:
                 left, op, right = statement
                 if left == 'batch_id':
@@ -83,42 +84,39 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int])
                 elif left == 'state':
                     queries.append(BatchStateQuery.parse(op, right))
                 elif left == 'start_time':
-                    st_query = BatchStartTimeQuery.parse(op, right)
+                    st_query = JobGroupStartTimeQuery.parse(op, right)
                     queries.append(st_query)
                     if (type(st_query.operator) in [GreaterThanOperator, GreaterThanEqualOperator]) and (
                         min_start_gt_query is None or min_start_gt_query.time_msecs >= st_query.time_msecs
                     ):
                         min_start_gt_query = st_query
                 elif left == 'end_time':
-                    et_query = BatchEndTimeQuery.parse(op, right)
+                    et_query = JobGroupEndTimeQuery.parse(op, right)
                     queries.append(et_query)
                     if (type(et_query.operator) in [LessThanOperator, LessThanEqualOperator]) and (
                         max_end_lt_query is None or max_end_lt_query.time_msecs <= et_query.time_msecs
                     ):
                         max_end_lt_query = et_query
                 elif left == 'duration':
-                    queries.append(BatchDurationQuery.parse(op, right))
+                    queries.append(JobGroupDurationQuery.parse(op, right))
                 elif left == 'cost':
-                    queries.append(BatchCostQuery.parse(op, right))
+                    queries.append(JobGroupCostQuery.parse(op, right))
                 else:
-                    queries.append(BatchKeywordQuery.parse(op, left, right))
+                    queries.append(JobGroupKeywordQuery.parse(op, left, right))
             else:
                 raise QueryError(f'could not parse term "{_term}"')
 
     # this is to make time interval queries fast by using the bounds on both indices
     if min_start_gt_query and max_end_lt_query and min_start_gt_query.time_msecs <= max_end_lt_query.time_msecs:
-        queries.append(BatchStartTimeQuery(max_end_lt_query.operator, max_end_lt_query.time_msecs))
-        queries.append(BatchEndTimeQuery(min_start_gt_query.operator, min_start_gt_query.time_msecs))
+        queries.append(JobGroupStartTimeQuery(max_end_lt_query.operator, max_end_lt_query.time_msecs))
+        queries.append(JobGroupEndTimeQuery(min_start_gt_query.operator, min_start_gt_query.time_msecs))
 
     # batch has already been validated
-    where_conditions = [
-        '(billing_project_users.`user` = %s)',
-        'NOT deleted',
-    ]
-    where_args: List[Any] = [user]
+    where_conditions = ['(billing_project_users.`user` = %s)', 'NOT deleted', 'job_groups.job_group_id = %s']
+    where_args: List[Any] = [user, ROOT_JOB_GROUP_ID]
 
     if last_batch_id is not None:
-        where_conditions.append('(batches.id < %s)')
+        where_conditions.append('(job_groups.batch_id < %s)')
         where_args.append(last_batch_id)
 
     for query in queries:
@@ -127,31 +125,31 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int])
         where_args += args
 
     sql = f'''
-SELECT batches.*,
-       job_groups_cancelled.id IS NOT NULL AS cancelled,
-       job_groups_n_jobs_in_complete_states.n_completed,
-       job_groups_n_jobs_in_complete_states.n_succeeded,
-       job_groups_n_jobs_in_complete_states.n_failed,
-       job_groups_n_jobs_in_complete_states.n_cancelled,
-       cost_t.cost, cost_t.cost_breakdown
-FROM batches
+SELECT batches.*, cost_t.cost, cost_t.cost_breakdown,
+       job_groups_cancelled.id IS NOT NULL AS cancelled,
+       job_groups_n_jobs_in_complete_states.n_completed,
+       job_groups_n_jobs_in_complete_states.n_succeeded,
+       job_groups_n_jobs_in_complete_states.n_failed,
+       job_groups_n_jobs_in_complete_states.n_cancelled
+FROM job_groups
+LEFT JOIN batches ON batches.id = job_groups.batch_id
 LEFT JOIN billing_projects ON batches.billing_project = billing_projects.name
-LEFT JOIN job_groups_n_jobs_in_complete_states ON batches.id = job_groups_n_jobs_in_complete_states.id
-LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id
+LEFT JOIN job_groups_n_jobs_in_complete_states ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id
+LEFT JOIN job_groups_cancelled ON job_groups.batch_id = job_groups_cancelled.id AND job_groups.job_group_id = job_groups_cancelled.job_group_id
 STRAIGHT_JOIN billing_project_users ON batches.billing_project = billing_project_users.billing_project
 LEFT JOIN LATERAL (
   SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown
   FROM (
-    SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
+    SELECT batch_id, job_group_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage`
     FROM aggregated_job_group_resources_v3
-    WHERE batches.id = aggregated_job_group_resources_v3.batch_id
-    GROUP BY batch_id, resource_id
+    WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id
+    GROUP BY batch_id, job_group_id, resource_id
   ) AS usage_t
   LEFT JOIN resources ON usage_t.resource_id = resources.resource_id
-  GROUP BY batch_id
+  GROUP BY batch_id, job_group_id
 ) AS cost_t ON TRUE
 WHERE {' AND '.join(where_conditions)}
-ORDER BY id DESC
+ORDER BY batches.id DESC
 LIMIT 51;
 '''
diff --git a/batch/batch/globals.py b/batch/batch/globals.py
index 5014241eac11..0e25f29ea4aa 100644
--- a/batch/batch/globals.py
+++ b/batch/batch/globals.py
@@ -21,7 +21,7 @@
 BATCH_FORMAT_VERSION = 7
 STATUS_FORMAT_VERSION = 5
-INSTANCE_VERSION = 26
+INSTANCE_VERSION = 27
 
 MAX_PERSISTENT_SSD_SIZE_GIB = 64 * 1024
 RESERVED_STORAGE_GB_PER_CORE = 5
diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py
index 5fd7dc685864..d983bdfbd7d8 100644
--- a/batch/batch/worker/worker.py
+++ b/batch/batch/worker/worker.py
@@ -1609,6 +1609,10 @@ def __init__(
 
         self.project_id = Job.get_next_xfsquota_project_id()
 
+    @property
+    def job_group_id(self):
+        return self.job_spec['job_group_id']
+
     @property
     def job_id(self):
         return self.job_spec['job_id']
@@ -1736,6 +1740,7 @@ def __init__(
             {'name': 'HAIL_REGION', 'value': REGION},
             {'name': 'HAIL_BATCH_ID', 'value': str(batch_id)},
             {'name': 'HAIL_JOB_ID', 'value': str(self.job_id)},
+            {'name': 'HAIL_JOB_GROUP_ID', 'value': str(self.job_group_id)},
             {'name': 'HAIL_ATTEMPT_ID', 'value': str(self.attempt_id)},
             {'name': 'HAIL_IDENTITY_PROVIDER_JSON', 'value': json.dumps(self.credentials.identity_provider_json)},
         ]
@@ -3076,6 +3081,7 @@ async def create_job_1(self, request):
 
         job_spec = await self.file_store.read_spec_file(batch_id, token, start_job_id, job_id)
         job_spec = json.loads(job_spec)
+        job_spec['job_group_id'] = addtl_spec['job_group_id']
 
         job_spec['attempt_id'] = addtl_spec['attempt_id']
         job_spec['secrets'] = addtl_spec['secrets']
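
Note on the fold(d, key_f) bookkeeping in check_resource_aggregation above: once job_group_id joins the
key tuples, attempt keys become (batch_id, job_group_id, job_id, attempt_id) and the per-job-group keys
become (batch_id, job_group_id, billing_project), so the billing project is k[2], not k[1]. A minimal,
self-contained sketch of the intended regrouping (this fold is an illustrative stand-in; the driver's
actual implementation lives outside this patch's context lines):

from collections import defaultdict

def fold(d, key_f):
    # Regroup per-key resource dicts under a coarser key, summing usage per resource.
    out = defaultdict(lambda: defaultdict(int))
    for key, resources in d.items():
        for resource, usage in resources.items():
            out[key_f(key)][resource] += usage
    return {k: dict(v) for k, v in out.items()}

# Hypothetical data keyed by (batch_id, job_group_id, billing_project):
agg_job_group_resources = {
    (1, 0, 'bp-a'): {'compute/cores': 100},
    (1, 2, 'bp-a'): {'compute/cores': 50},
}
assert fold(agg_job_group_resources, lambda k: k[2]) == {'bp-a': {'compute/cores': 150}}
assert fold(agg_job_group_resources, lambda k: (k[0], k[1])) == {
    (1, 0): {'compute/cores': 100},
    (1, 2): {'compute/cores': 50},
}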