Skip to content

Commit

Permalink
Filter out malformed nvidia-smi process_name XML tag
Browse files Browse the repository at this point in the history
  • Loading branch information
jfennick committed Sep 19, 2023
1 parent 20f01e0 commit 7945f7d
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions cwltool/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,37 @@

def cuda_version_and_device_count() -> Tuple[str, int]:
"""Determine the CUDA version and number of attached CUDA GPUs."""
# For the number of GPUs, we can use the following query
cmd_count = ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader"]
try:
out = subprocess.check_output(["nvidia-smi", "-q", "-x"]) # nosec
out_count = subprocess.check_output(cmd_count) # nosec
except Exception as e:
_logger.warning("Error checking number of GPUs with nvidia-smi: %s", e)
return ("", 0)
count = int(out_count)

# There is no dedicated query for the CUDA version, so we have to use
# `nvidia-smi -q -x`.
# However, nvidia-smi is apparently not safe to call concurrently.
# With --parallel, the returned XML will sometimes contain
# <process_name>\xff...\xff</process_name>
# (or other arbitrary bytes), and xml.dom.minidom.parseString will raise
# "xml.parsers.expat.ExpatError: not well-formed (invalid token)".
# So we either need to repair the process_name tag or, better yet, keep only
# the line we need by piping the output through `grep cuda_version`.
cmd_cuda_version = "nvidia-smi -q -x | grep cuda_version"
try:
out = subprocess.check_output(cmd_cuda_version, shell=True) # nosec
except Exception as e:
_logger.warning("Error checking CUDA version with nvidia-smi: %s", e)
return ("", 0)
dm = xml.dom.minidom.parseString(out) # nosec

ag = dm.getElementsByTagName("attached_gpus")
if len(ag) < 1 or ag[0].firstChild is None:
_logger.warning(
"Error checking CUDA version with nvidia-smi. Missing 'attached_gpus' or it is empty.: %s",
out,
)
try:
dm = xml.dom.minidom.parseString(out) # nosec
except xml.parsers.expat.ExpatError as e:
_logger.warning("Error parsing XML stdout of nvidia-smi: %s", e)
_logger.warning("stdout: %s", out)
return ("", 0)
ag_element = ag[0].firstChild

cv = dm.getElementsByTagName("cuda_version")
if len(cv) < 1 or cv[0].firstChild is None:
Expand All @@ -35,13 +51,11 @@ def cuda_version_and_device_count() -> Tuple[str, int]:
return ("", 0)
cv_element = cv[0].firstChild

if isinstance(cv_element, xml.dom.minidom.Text) and isinstance(
ag_element, xml.dom.minidom.Text
):
return (cv_element.data, int(ag_element.data))
if isinstance(cv_element, xml.dom.minidom.Text):
return (cv_element.data, count)
_logger.warning(
"Error checking CUDA version with nvidia-smi. "
"Either 'attached_gpus' or 'cuda_version' was not a text node: %s",
"'cuda_version' was not a text node: %s",
out,
)
return ("", 0)
Expand Down

0 comments on commit 7945f7d

Please sign in to comment.