From a20ea65e32db17d698465e5f04bfe266a368ac26 Mon Sep 17 00:00:00 2001 From: Alan Malta Rodrigues Date: Wed, 16 Aug 2023 06:05:32 -0400 Subject: [PATCH] Update WMTask to return a unique list of CUDARuntime --- src/python/WMCore/WMSpec/WMTask.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/python/WMCore/WMSpec/WMTask.py b/src/python/WMCore/WMSpec/WMTask.py index 249eadf9e8e..c3ed99dd57a 100644 --- a/src/python/WMCore/WMSpec/WMTask.py +++ b/src/python/WMCore/WMSpec/WMTask.py @@ -1525,16 +1525,32 @@ def getRequiresGPU(self): def getGPURequirements(self): """ Return the GPU requirements for this task. - If it's a multi-step task, the first step with a meaningful - dictionary value will be returned + For multi-step tasks, the following logic is applied: + * GPUMemoryMB: return the max of them + * CUDARuntime: returns a flat list of unique runtime versions + * CUDACapabilities: returns a flat list of unique capabilities :return: a dictionary with the GPU requirements for this task """ - gpuRequirements = {} + gpuRequirements = [] for stepName in sorted(self.listAllStepNames()): stepHelper = self.getStep(stepName) if stepHelper.stepType() == "CMSSW" and stepHelper.getGPURequirements(): - return stepHelper.getGPURequirements() - return gpuRequirements + gpuRequirements.append(stepHelper.getGPURequirements()) + if not gpuRequirements: + return {} + + # in this case, it requires GPUs and it can be multi-steps GPU + bestGPUParams = gpuRequirements.pop(0) + bestGPUParams["CUDARuntime"] = [bestGPUParams["CUDARuntime"]] + for params in gpuRequirements: + if params["GPUMemoryMB"] > bestGPUParams["GPUMemoryMB"]: + bestGPUParams["GPUMemoryMB"] = params["GPUMemoryMB"] + bestGPUParams["CUDARuntime"].append(params["CUDARuntime"]) + bestGPUParams["CUDACapabilities"].extend(params["CUDACapabilities"]) + # make the flat list elements unique + bestGPUParams["CUDARuntime"] = list(set(bestGPUParams["CUDARuntime"])) + bestGPUParams["CUDACapabilities"] = list(set(bestGPUParams["CUDACapabilities"])) + return bestGPUParams def _getStepValue(self, keyDict, defaultValue): """