Skip to content

Commit

Permalink
New Site.acceleratorParams as json-like dictionary (#7470)
Browse files Browse the repository at this point in the history
* accelerator params

* fix keyerror and change request gpu back to request_GPUs

Co-authored-by: Thanayut Seethongchuen <=>
  • Loading branch information
novicecpp authored Dec 5, 2022
1 parent d1fc2b2 commit 452dc41
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 10 deletions.
20 changes: 16 additions & 4 deletions src/python/CRABInterface/RESTUserWorkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from CRABInterface.Regexps import (RX_TASKNAME, RX_ACTIVITY, RX_JOBTYPE, RX_GENERATOR, RX_LUMIEVENTS, RX_CMSSW, RX_ARCH, RX_DATASET,
RX_CMSSITE, RX_SPLIT, RX_CACHENAME, RX_CACHEURL, RX_LFN, RX_USERFILE, RX_VOPARAMS, RX_DBSURL, RX_LFNPRIMDS, RX_OUTFILES,
RX_RUNS, RX_LUMIRANGE, RX_SCRIPTARGS, RX_SCHEDD_NAME, RX_COLLECTOR, RX_SUBRESTAT, RX_JOBID, RX_ADDFILE,
RX_ANYTHING, RX_USERNAME, RX_DATE, RX_MANYLINES_SHORT)
from CRABInterface.Utilities import CMSSitesCache, conn_handler, getDBinstance
RX_ANYTHING, RX_USERNAME, RX_DATE, RX_MANYLINES_SHORT, RX_CUDA_VERSION)
from CRABInterface.Utilities import CMSSitesCache, conn_handler, getDBinstance, validate_dict
from ServerUtilities import checkOutLFN, generateTaskName


Expand Down Expand Up @@ -417,6 +417,16 @@ def validate(self, apiobj, method, api, param, safe): #pylint: disable=unused-ar
validate_num("ignoreglobalblacklist", param, safe, optional=True)
validate_num("partialdataset", param, safe, optional=True)
validate_num("requireaccelerator", param, safe, optional=True)
# validate optional acceleratorparams
if param.kwargs.get("acceleratorparams", None):
if not safe.kwargs["requireaccelerator"]:
raise InvalidParameter("There are accelerator parameters but requireAccelerator is False")
with validate_dict("acceleratorparams", param, safe) as (accParams, accSafe):
validate_num("GPUMemoryMB", accParams, accSafe, minval=0, optional=True)
validate_strlist("CUDACapabilities", accParams, accSafe, RX_CUDA_VERSION)
validate_str("CUDARuntime", accParams, accSafe, RX_CUDA_VERSION, optional=True)
else:
safe.kwargs["acceleratorparams"] = None

elif method in ['POST']:
validate_str("workflow", param, safe, RX_TASKNAME, optional=False)
Expand Down Expand Up @@ -488,7 +498,7 @@ def put(self, workflow, activity, jobtype, jobsw, jobarch, inputdata, primarydat
tfileoutfiles, edmoutfiles, runs, lumis,
totalunits, adduserfiles, oneEventMode, maxjobruntime, numcores, maxmemory, priority, blacklistT1, nonprodsw, lfn, saveoutput,
faillimit, ignorelocality, userfiles, scriptexe, scriptargs, scheddname, extrajdl, collector, dryrun, ignoreglobalblacklist,
partialdataset, requireaccelerator):
partialdataset, requireaccelerator, acceleratorparams):
"""Perform the workflow injection
:arg str workflow: request name defined by the user;
Expand Down Expand Up @@ -545,9 +555,11 @@ def put(self, workflow, activity, jobtype, jobsw, jobarch, inputdata, primarydat

user_config = {
'partialdataset': True if partialdataset else False,
'requireaccelerator': True if requireaccelerator else False
'requireaccelerator': True if requireaccelerator else False,
'acceleratorparams': acceleratorparams if acceleratorparams else None,
}


return self.userworkflowmgr.submit(workflow=workflow, activity=activity, jobtype=jobtype, jobsw=jobsw, jobarch=jobarch,
inputdata=inputdata, primarydataset=primarydataset, nonvaliddata=nonvaliddata, use_parent=useparent,
secondarydata=secondarydata, generator=generator, events_per_lumi=eventsperlumi,
Expand Down
1 change: 1 addition & 0 deletions src/python/CRABInterface/RESTWorkerWorkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def fixupTask(task):
# Default values for the keys of the task's user config.
# The key must be spelled 'acceleratorparams': that is the name the REST
# submit path stores and the name DagmanCreator looks up in tm_user_config.
user_config_default = {
    'partialdataset': False,
    'requireaccelerator': False,
    'acceleratorparams': None,
}
if result['tm_user_config']:
result['tm_user_config'] = json.loads(result['tm_user_config'])
Expand Down
3 changes: 3 additions & 0 deletions src/python/CRABInterface/Regexps.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,6 @@

RX_SUBGETUSERTRANSFER = re.compile(r"^(getById|getTransferStatus|getPublicationStatus)$")
RX_SUBPOSTUSERTRANSFER = re.compile(r"^(killTransfers|retryPublication|retryTransfers|killTransfersById|updateDoc)$")

# CUDA-style dotted version with two or three numeric fields, e.g. 11.4 or 515.43.04.
# \Z (instead of $) anchors at the real end of the string: with $ a value
# carrying a trailing newline such as "11.4\n" would still be accepted.
RX_CUDA_VERSION = re.compile(r"^\d+\.\d+(\.\d+)?\Z")
66 changes: 62 additions & 4 deletions src/python/CRABInterface/Utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import print_function
import logging
import os
from contextlib import contextmanager, nullcontext
from collections import namedtuple
from time import mktime, gmtime
import re
Expand All @@ -10,11 +11,13 @@
import pycurl
import io
import json
import copy

from WMCore.WMFactory import WMFactory
from WMCore.REST.Error import ExecutionError, InvalidParameter
from WMCore.Services.CRIC.CRIC import CRIC
from WMCore.Services.pycurl_manager import ResponseHeader
from WMCore.REST.Server import RESTArgs

from Utils.Utilities import encodeUnicodeToBytes

Expand Down Expand Up @@ -113,9 +116,9 @@ def retrieveConfig(externalLink):
return centralCfgFallback
else:
cherrypy.log(msg)
raise ExecutionError("Internal issue when retrieving external configuration from %s" % externalLink)
jsonConfig = bbuf.getvalue()
raise ExecutionError("Internal issue when retrieving external configuration from %s" % externalLink)
jsonConfig = bbuf.getvalue()

return jsonConfig

extConfCommon = json.loads(retrieveConfig(extconfigurl))
Expand Down Expand Up @@ -151,7 +154,7 @@ def retrieveConfig(externalLink):
else:
extConfCommon["backend-urls"]["htcondorSchedds"] = extConfSchedds
centralCfgFallback = extConfCommon

return centralCfgFallback


Expand All @@ -174,3 +177,58 @@ def wrapped_func(*args, **kwargs):
return wrapped_func
return wrap


@contextmanager
def validate_dict(argname, param, safe, maxjsonsize=1024):
    """
    Context manager to validate the keys of a dict argument passed as JSON.

    On entry, the value of `param.kwargs[argname]` is checked: it must be
    present, be a JSON string no longer than `maxjsonsize`, and decode
    (with json.loads) to a dict.

    The `with` statement then receives a tuple of RESTArgs
    `(dictParam, dictSafe)`. The nested block is expected to call the usual
    `validate_*` helpers on every key of the decoded dict, in the same way
    `param`/`safe` are used in DatabaseRESTApi.validate(), but against
    `dictParam`/`dictSafe` instead.

    On exit, if the nested block left any key in `dictParam.kwargs`
    (i.e. not validated), InvalidParameter is raised. Otherwise the decoded
    dict object (not the JSON string) is stored in `safe.kwargs[argname]`
    and the original string is removed from `param.kwargs`.

    Note that validate_dict itself does not support optional arguments:
    the caller must check for presence first.

    Example in DatabaseRESTApi.validate() to validate the optional
    "acceleratorparams" dict parameter with 1 mandatory and 2 optional keys:

        if param.kwargs.get("acceleratorparams", None):
            with validate_dict("acceleratorparams", param, safe) as (accParams, accSafe):
                custom_err = "Incorrect '{}' parameter. Parameter is also required when Site.requireAccelerator is True"
                validate_num("GPUMemoryMB", accParams, accSafe, minval=0, custom_err=custom_err.format("GPUMemoryMB"))
                validate_strlist("CUDACapabilities", accParams, accSafe, RX_CUDA_VERSION)
                validate_str("CUDARuntime", accParams, accSafe, RX_CUDA_VERSION, optional=True)
        else:
            safe.kwargs["acceleratorparams"] = None

    :arg str argname: name of the keyword argument holding the JSON string
    :arg RESTArgs param: not-yet-validated request arguments
    :arg RESTArgs safe: validated request arguments
    :arg int maxjsonsize: maximum accepted length of the JSON string
    :raises InvalidParameter: if the argument is missing, too long, not
        valid JSON, not a JSON object, or has keys left unvalidated
    """

    val = param.kwargs.get(argname, None)
    if val is None:
        # Without this guard len(None) would escape as a bare TypeError
        # instead of being reported as a parameter validation failure.
        raise InvalidParameter("Param is not defined")
    if not isinstance(val, (str, bytes, bytearray)):
        # Anything else (e.g. a list from a repeated argument) cannot be
        # decoded by json.loads; reject it with the same error as bad JSON.
        raise InvalidParameter("Param is not valid JSON-like dict object")
    # Check the size before decoding so oversized payloads are never parsed.
    if len(val) > maxjsonsize:
        raise InvalidParameter(f"Param is larger than {maxjsonsize} bytes")
    try:
        data = json.loads(val)
    except (ValueError, RecursionError) as e:
        # ValueError covers JSONDecodeError and undecodable bytes;
        # RecursionError covers pathologically nested input.
        raise InvalidParameter("Param is not valid JSON-like dict object") from e
    if data is None:
        # JSON literal `null`
        raise InvalidParameter("Param is not defined")
    if not isinstance(data, dict):
        raise InvalidParameter("Param is not a dictionary encoded as JSON object")

    # validate_* helpers work on the deep copy, so `data` itself is left untouched
    dictParam = RESTArgs([], copy.deepcopy(data))
    dictSafe = RESTArgs([], {})
    yield (dictParam, dictSafe)

    # Whatever is still in dictParam was not claimed by any validate_* call.
    if dictParam.kwargs:
        raise InvalidParameter(f"Excess keyword arguments inside keyword argument, not validated kwargs={{'{argname}': {dictParam.kwargs}}}")
    # NOTE(review): the decoded dict is stored as-is; the contents of
    # dictSafe.kwargs are not propagated to `safe` - confirm this is intended.
    safe.kwargs[argname] = data
    del param.kwargs[argname]
14 changes: 12 additions & 2 deletions src/python/TaskWorker/Actions/DagmanCreator.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,23 @@ def makeJobSubmit(self, task):
info['accounting_group_user'] = info['userhn']
info = transform_strings(info)
info['faillimit'] = task['tm_fail_limit']
# hardcoding accelerator to GPU (SI currently only have nvidia GPU)
if task['tm_user_config']['requireaccelerator']:
# hardcoding accelerator to GPU (SI currently only have nvidia GPU)
info['accelerator_jdl'] = '+RequiresGPU=1\nrequest_GPUs=1'
if task['tm_user_config']['acceleratorparams']:
gpuMemoryMB = task['tm_user_config']['acceleratorparams'].get('GPUMemoryMB', None)
cudaCapabilities = task['tm_user_config']['acceleratorparams'].get('CUDACapabilities', None)
cudaRuntime = task['tm_user_config']['acceleratorparams'].get('CUDARuntime', None)
if gpuMemoryMB:
info['accelerator_jdl'] += f"\n+GPUMemoryMB={gpuMemoryMB}"
if cudaCapabilities:
cudaCapability = ','.join(sorted(cudaCapabilities))
info['accelerator_jdl'] += f"\n+CUDACapability={classad.quote(cudaCapability)}"
if cudaRuntime:
info['accelerator_jdl'] += f"\n+CUDARuntime={classad.quote(cudaRuntime)}"
else:
info['accelerator_jdl'] = ''
info['extra_jdl'] = '\n'.join(literal_eval(task['tm_extrajdl']))

# info['jobarch_flatten'].split("_")[0]: extracts "slc7" from "slc7_amd64_gcc10"
required_os_list = ARCH_TO_OS.get(info['jobarch_flatten'].split("_")[0])
# ARCH_TO_OS.get("slc7") gives a list with one item only: ['rhel7']
Expand Down
4 changes: 4 additions & 0 deletions src/python/TaskWorker/Actions/DagmanSubmitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def addCRABInfoToClassAd(ad, info):
for jdl in info['extra_jdl'].split('\n'):
adName, adVal = jdl.lstrip('+').split('=', 1)
ad[adName] = adVal
if 'accelerator_jdl' in info and info['accelerator_jdl']:
for jdl in info['accelerator_jdl'].split('\n'):
adName, adVal = jdl.lstrip('+').split('=', 1)
ad[adName] = classad.ExprTree(str(adVal))


class ScheddStats(dict):
Expand Down

0 comments on commit 452dc41

Please sign in to comment.