diff --git a/jobrunner/cli/kill_job.py b/jobrunner/cli/kill_job.py index 521f4871..c93ccdd3 100644 --- a/jobrunner/cli/kill_job.py +++ b/jobrunner/cli/kill_job.py @@ -55,20 +55,25 @@ def get_jobs(partial_job_ids): jobs = [] need_confirmation = False for partial_job_id in partial_job_ids: - matches = database.find_where(Job, id__like=f"%{partial_job_id}%") - if len(matches) == 0: + # look for partial matches + partial_matches = database.find_where(Job, id__like=f"%{partial_job_id}%") + if len(partial_matches) == 0: raise RuntimeError(f"No jobs found matching '{partial_job_id}'") - elif len(matches) > 1: + elif len(partial_matches) > 1: print(f"Multiple jobs found matching '{partial_job_id}':") - for i, job in enumerate(matches, start=1): + for i, job in enumerate(partial_matches, start=1): print(f" {i}: {job.slug}") print() index = int(input("Enter number: ")) - assert 0 < index <= len(matches) - jobs.append(matches[index - 1]) + assert 0 < index <= len(partial_matches) + jobs.append(partial_matches[index - 1]) else: - need_confirmation = True - jobs.append(matches[0]) + # We only need confirmation if the supplied job ID doesn't exactly + # match the found job + job = partial_matches[0] + if job.id != partial_job_id: + need_confirmation = True + jobs.append(job) if need_confirmation: print("About to kill jobs:") for job in jobs: diff --git a/tests/cli/test_kill_job.py b/tests/cli/test_kill_job.py index 0b5dc52d..1389985f 100644 --- a/tests/cli/test_kill_job.py +++ b/tests/cli/test_kill_job.py @@ -9,6 +9,120 @@ from tests.factories import job_factory +def test_get_jobs_no_jobs(db): + + # set a string to use as a partial id + partial_job_id = "1234" + partial_job_ids = [partial_job_id] + + with pytest.raises(RuntimeError): + kill_job.get_jobs(partial_job_ids) + + +def test_get_jobs_no_match(db): + + # make a fake job + job_factory( + state=State.RUNNING, status_code=StatusCode.EXECUTING, id="z6tkp3mjato63dkm" + ) + + partial_job_id = "1234" + partial_job_ids = [partial_job_id] + + with pytest.raises(RuntimeError): + kill_job.get_jobs(partial_job_ids) + + +def test_get_jobs_multiple_matches(db, monkeypatch): + + # make a fake job + job = job_factory( + state=State.RUNNING, status_code=StatusCode.EXECUTING, id="z6tkp3mjato63dkm" + ) + + job_factory( + state=State.RUNNING, status_code=StatusCode.EXECUTING, id="z6tkp3mjato63dkn" + ) + + partial_job_id = "kp3mj" + partial_job_ids = [partial_job_id] + + monkeypatch.setattr("builtins.input", lambda _: "1") + + output_job_ids = kill_job.get_jobs(partial_job_ids) + + assert output_job_ids[0].id == job.id + + +def test_get_jobs_multiple_params_partial(db, monkeypatch): + + job1 = job_factory( + state=State.RUNNING, status_code=StatusCode.EXECUTING, id="z6tkp3mjato63dkm" + ) + + job2 = job_factory( + state=State.RUNNING, status_code=StatusCode.EXECUTING, id="z6tkp3mjato63dkn" + ) + + partial_job_ids = ["dkm", "dkn"] + + monkeypatch.setattr("builtins.input", lambda _: "") + + # search for jobs with our partial id + output_job_ids = kill_job.get_jobs(partial_job_ids) + + assert output_job_ids[0].id == job1.id + assert output_job_ids[1].id == job2.id + + +def test_get_jobs_partial_id(db, monkeypatch): + # make a fake job + job = job_factory(state=State.RUNNING, status_code=StatusCode.EXECUTING) + + # take the first four characters to make a partial id + partial_job_id = job.id[:4] + partial_job_ids = [partial_job_id] + + monkeypatch.setattr("builtins.input", lambda _: "") + + # search for jobs with our partial id + output_job_ids = kill_job.get_jobs(partial_job_ids) + + assert output_job_ids[0].id == job.id + + +def test_get_jobs_partial_id_quit(db, monkeypatch): + # make a fake job + job = job_factory(state=State.RUNNING, status_code=StatusCode.EXECUTING) + + # take the first four characters to make a partial id + partial_job_id = job.id[:4] + partial_job_ids = [partial_job_id] + + def press_control_c(_): + raise KeyboardInterrupt() + + monkeypatch.setattr("builtins.input", press_control_c) + + # make sure the program is quit + with pytest.raises(KeyboardInterrupt): + kill_job.get_jobs(partial_job_ids) + + +def test_get_jobs_full_id(db): + # make a fake job + job = job_factory(state=State.RUNNING, status_code=StatusCode.EXECUTING) + + # this "partial id" is secretly a full id!! + full_job_id = job.id + full_job_ids = [full_job_id] + + # search for jobs with our partial id + output_job_ids = kill_job.get_jobs(full_job_ids) + + assert output_job_ids[0].id == job.id + + @pytest.mark.needs_docker @pytest.mark.parametrize("cleanup", [False, True]) def test_kill_job(cleanup, tmp_work_dir, db, monkeypatch):