diff --git a/src/DIRAC/WorkloadManagementSystem/FutureClient/JobStateUpdateClient.py b/src/DIRAC/WorkloadManagementSystem/FutureClient/JobStateUpdateClient.py index f8f662d7b30..4a138c652df 100644 --- a/src/DIRAC/WorkloadManagementSystem/FutureClient/JobStateUpdateClient.py +++ b/src/DIRAC/WorkloadManagementSystem/FutureClient/JobStateUpdateClient.py @@ -1,50 +1,128 @@ +import functools +from datetime import datetime, timezone + + from DIRAC.Core.Security.DiracX import DiracXClient from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue +from DIRAC.Core.Utilities.TimeUtilities import fromString + + +def stripValueIfOK(func): + """Decorator to remove S_OK["Value"] from the return value of a function if it is OK. + + This is done as some update functions return the number of modified rows in + the database. This likely not actually useful so it isn't supported in + DiracX. Stripping the "Value" key of the dictionary means that we should + get a fairly straight forward error if the assumption is incorrect. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + if result.get("OK"): + assert result.pop("Value") is None, "Value should be None if OK" + return result + + return wrapper class JobStateUpdateClient: def sendHeartBeat(self, jobID: str | int, dynamicData: dict, staticData: dict): raise NotImplementedError("TODO") + @stripValueIfOK + @convertToReturnValue def setJobApplicationStatus(self, jobID: str | int, appStatus: str, source: str = "Unknown"): - raise NotImplementedError("TODO") + statusDict = { + "application_status": appStatus, + } + if source: + statusDict["Source"] = source + with DiracXClient() as api: + api.jobs.set_single_job_status( + jobID, + {datetime.now(tz=timezone.utc): statusDict}, + ) + @stripValueIfOK + @convertToReturnValue def setJobAttribute(self, jobID: str | int, attribute: str, value: str): with DiracXClient() as api: - api.jobs.set_single_job_properties(jobID, "need to [patch the client to have a nice summer body ?") - raise NotImplementedError("TODO") + if attribute == "Status": + api.jobs.set_single_job_status( + jobID, + {datetime.now(tz=timezone.utc): {"status": value}}, + ) + else: + api.jobs.set_single_job_properties(jobID, {attribute: value}) + @stripValueIfOK + @convertToReturnValue def setJobFlag(self, jobID: str | int, flag: str): - raise NotImplementedError("TODO") + with DiracXClient() as api: + api.jobs.set_single_job_properties(jobID, {flag: True}) + @stripValueIfOK + @convertToReturnValue def setJobParameter(self, jobID: str | int, name: str, value: str): - raise NotImplementedError("TODO") + print("HACK: This is a no-op until we decide what to do") + @stripValueIfOK + @convertToReturnValue def setJobParameters(self, jobID: str | int, parameters: list): - raise NotImplementedError("TODO") + print("HACK: This is a no-op until we decide what to do") + @stripValueIfOK + @convertToReturnValue def setJobSite(self, jobID: str | int, site: str): - raise NotImplementedError("TODO") + with DiracXClient() as api: + api.jobs.set_single_job_properties(jobID, {"Site": site}) + @stripValueIfOK + @convertToReturnValue def setJobStatus( self, jobID: str | int, status: str = "", minorStatus: str = "", source: str = "Unknown", - datetime=None, + datetime_=None, force=False, ): - raise NotImplementedError("TODO") + statusDict = {} + if status: + statusDict["Status"] = status + if minorStatus: + statusDict["MinorStatus"] = minorStatus + if source: + statusDict["Source"] = source + if datetime_ is None: + datetime_ = datetime.utcnow() + with DiracXClient() as api: + api.jobs.set_single_job_status( + jobID, + {fromString(datetime_).replace(tzinfo=timezone.utc): statusDict}, + force=force, + ) + @stripValueIfOK + @convertToReturnValue def setJobStatusBulk(self, jobID: str | int, statusDict: dict, force=False): - raise NotImplementedError("TODO") + statusDict = {fromString(k).replace(tzinfo=timezone.utc): v for k, v in statusDict.items()} + with DiracXClient() as api: + api.jobs.set_job_status_bulk( + {jobID: statusDict}, + force=force, + ) def setJobsParameter(self, jobsParameterDict: dict): raise NotImplementedError("TODO") + @stripValueIfOK + @convertToReturnValue def unsetJobFlag(self, jobID: str | int, flag: str): - raise NotImplementedError("TODO") + with DiracXClient() as api: + api.jobs.set_single_job_properties(jobID, {flag: False}) def updateJobFromStager(self, jobID: str | int, status: str): raise NotImplementedError("TODO") diff --git a/tests/Integration/FutureClient/WorkloadManagement/Test_JobStateUpdate.py b/tests/Integration/FutureClient/WorkloadManagement/Test_JobStateUpdate.py index c9fbf920367..2ddebce69d3 100644 --- a/tests/Integration/FutureClient/WorkloadManagement/Test_JobStateUpdate.py +++ b/tests/Integration/FutureClient/WorkloadManagement/Test_JobStateUpdate.py @@ -1,12 +1,45 @@ +from datetime import datetime from functools import partial +from textwrap import dedent import pytest import DIRAC DIRAC.initialize() +from DIRAC.Core.Security.DiracX import DiracXClient from DIRAC.WorkloadManagementSystem.Client.JobStateUpdateClient import JobStateUpdateClient -from ..utils import compare_results +from ..utils import compare_results2 + +test_jdl = """ +Arguments = "Hello world from DiracX"; +Executable = "echo"; +JobGroup = jobGroup; +JobName = jobName; +JobType = User; +LogLevel = INFO; +MinNumberOfProcessors = 1000; +OutputSandbox = + { + std.err, + std.out + }; +Priority = 1; +Sites = ANY; +StdError = std.err; +StdOutput = std.out; +""" + + +@pytest.fixture() +def example_jobids(): + from DIRAC.Interfaces.API.Dirac import Dirac + from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise + + d = Dirac() + job_id_1 = returnValueOrRaise(d.submitJob(test_jdl)) + job_id_2 = returnValueOrRaise(d.submitJob(test_jdl)) + return job_id_1, job_id_2 def test_sendHeartBeat(monkeypatch): @@ -15,16 +48,22 @@ def test_sendHeartBeat(monkeypatch): pytest.skip() -def test_setJobApplicationStatus(monkeypatch): +def test_setJobApplicationStatus(monkeypatch, example_jobids): # JobStateUpdateClient().setJobApplicationStatus(jobID: str | int, appStatus: str, source: str = Unknown) method = JobStateUpdateClient().setJobApplicationStatus - pytest.skip() + args = ["MyApplicationStatus"] + test_func1 = partial(method, example_jobids[0], *args) + test_func2 = partial(method, example_jobids[1], *args) + compare_results2(monkeypatch, test_func1, test_func2) -def test_setJobAttribute(monkeypatch): +@pytest.mark.parametrize("args", [["Status", "Killed"], ["JobGroup", "newJobGroup"]]) +def test_setJobAttribute(monkeypatch, example_jobids, args): # JobStateUpdateClient().setJobAttribute(jobID: str | int, attribute: str, value: str) method = JobStateUpdateClient().setJobAttribute - pytest.skip() + test_func1 = partial(method, example_jobids[0], *args) + test_func2 = partial(method, example_jobids[1], *args) + compare_results2(monkeypatch, test_func1, test_func2) def test_setJobFlag(monkeypatch): @@ -45,22 +84,37 @@ def test_setJobParameters(monkeypatch): pytest.skip() -def test_setJobSite(monkeypatch): +@pytest.mark.parametrize("jobid_type", [int, str]) +def test_setJobSite(monkeypatch, example_jobids, jobid_type): # JobStateUpdateClient().setJobSite(jobID: str | int, site: str) method = JobStateUpdateClient().setJobSite - pytest.skip() + args = ["LCG.CERN.ch"] + test_func1 = partial(method, jobid_type(example_jobids[0]), *args) + test_func2 = partial(method, jobid_type(example_jobids[1]), *args) + compare_results2(monkeypatch, test_func1, test_func2) -def test_setJobStatus(monkeypatch): +def test_setJobStatus(monkeypatch, example_jobids): # JobStateUpdateClient().setJobStatus(jobID: str | int, status: str = , minorStatus: str = , source: str = Unknown, datetime = None, force = False) method = JobStateUpdateClient().setJobStatus - pytest.skip() + args = ["", "My Minor"] + test_func1 = partial(method, example_jobids[0], *args) + test_func2 = partial(method, example_jobids[1], *args) + compare_results2(monkeypatch, test_func1, test_func2) -def test_setJobStatusBulk(monkeypatch): +def test_setJobStatusBulk(monkeypatch, example_jobids): # JobStateUpdateClient().setJobStatusBulk(jobID: str | int, statusDict: dict, force = False) method = JobStateUpdateClient().setJobStatusBulk - pytest.skip() + args = [ + { + datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"): {"ApplicationStatus": "SomethingElse"}, + datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"): {"ApplicationStatus": "Something"}, + } + ] + test_func1 = partial(method, example_jobids[0], *args) + test_func2 = partial(method, example_jobids[1], *args) + compare_results2(monkeypatch, test_func1, test_func2) def test_setJobsParameter(monkeypatch): diff --git a/tests/Integration/FutureClient/utils.py b/tests/Integration/FutureClient/utils.py index 9fb9afdb2b5..fe1b08636b9 100644 --- a/tests/Integration/FutureClient/utils.py +++ b/tests/Integration/FutureClient/utils.py @@ -1,20 +1,47 @@ -def compare_results(test_func): +import time + + +def compare_results(monkeypatch, test_func): """Compare the results from DIRAC and DiracX based services for a reentrant function.""" - ClientClass = test_func.func.__self__ - assert ClientClass.diracxClient, "FutureClient is not set up!" + compare_results2(monkeypatch, test_func, test_func) + +def compare_results2(monkeypatch, test_func1, test_func2): + """Compare the results from DIRAC and DiracX based services for two functions which should behave identically.""" # Get the result from the diracx-based handler - future_result = test_func() + start = time.monotonic() + with monkeypatch.context() as m: + m.setattr("DIRAC.Core.Tornado.Client.ClientSelector.useLegacyAdapter", lambda *_: True) + try: + future_result = test_func1() + except Exception as e: + future_result = e + else: + assert "rpcStub" not in future_result, "rpcStub should never be present when using DiracX!" + diracx_duration = time.monotonic() - start # Get the result from the DIRAC-based handler - diracxClient = ClientClass.diracxClient - ClientClass.diracxClient = None - try: - old_result = test_func() - finally: - ClientClass.diracxClient = diracxClient - # We don't care about the rpcStub + start = time.monotonic() + with monkeypatch.context() as m: + m.setattr("DIRAC.Core.Tornado.Client.ClientSelector.useLegacyAdapter", lambda *_: False) + old_result = test_func2() + assert "rpcStub" in old_result, "rpcStub should always be present when using legacy DIRAC!" + legacy_duration = time.monotonic() - start + + # We don't care about the rpcStub or Errno old_result.pop("rpcStub") + old_result.pop("Errno", None) + + if not old_result["OK"]: + assert not future_result["OK"], "FutureClient should have failed too!" + elif "Value" in future_result: + # Ensure the results match exactly + assert old_result == future_result + else: + # See the "stripValueIfOK" decorator for explanation + assert old_result["OK"] == future_result["OK"] + # assert isinstance(old_result["Value"], int) - # Ensure the results match - assert old_result == future_result + # if 3 * legacy_duration < diracx_duration: + # print(f"Legacy DIRAC took {legacy_duration:.3f}s, FutureClient took {diracx_duration:.3f}s") + # assert False, "FutureClient should be faster than legacy DIRAC!"