1818import shlex
1919import subprocess
2020import tempfile
21+ import warnings
2122from dataclasses import dataclass
2223from datetime import datetime
2324from subprocess import CalledProcessError , PIPE
@@ -72,6 +73,55 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
7273 return SLURM_STATES .get (slurm_state , AppState .UNKNOWN )
7374
7475
def version() -> Tuple[int, int]:
    """
    Uses ``sinfo --version`` to get the slurm version. If the command fails
    (e.g. ``sinfo`` is not on ``PATH`` or returns a non-zero exit code), it
    assumes the version is ``slurm 24.05.8`` and emits a ``RuntimeWarning``.

    Returns:
        Slurm version as a tuple of ints ``(major, minor)``.
    """

    cmd = ["sinfo", "--version"]
    try:
        out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
    except (CalledProcessError, FileNotFoundError):
        # Fall back to a known-good default so that --dryrun works off-cluster.
        out = "slurm 24.05.8"
        warnings.warn(
            f"Error running: `{' '.join(cmd)}` to get SLURM version. Are you running outside the "
            "cluster's login or head node? This typically happens when running in `--dryrun`"
            " mode. Assuming version is `slurm 24.05.8`.",
            RuntimeWarning,
            stacklevel=2,
        )

    # sinfo --version returns in the form "slurm 24.05.8"; maxsplit=1 keeps the
    # unpack safe even if the output ever contains additional tokens.
    _, version_literal = out.split(" ", maxsplit=1)
    # Slice before int() so suffixes such as "24.05.8-1" in the patch segment
    # cannot break parsing of the (major, minor) pair.
    major, minor = (int(v) for v in version_literal.split(".")[:2])

    return (major, minor)
104+
105+
def _should_use_gpus_per_node_from_version() -> bool:
    """
    Determine whether to use gpus-per-node based on the automatically
    detected slurm version.

    Change Reference: https://fburl.com/sqwqzxn6
    > select/linear - Reject jobs asking for GRES per job|socket|task or cpus|mem per GRES.

    Returns:
        ``True`` in slurm ``version>=24.11.0``, ``False`` otherwise.
    """

    # Tuples compare lexicographically, so a single >= covers both cases:
    # a greater major version, or the same major with minor >= 11.
    return version() >= (24, 11)
123+
124+
75125SBATCH_JOB_OPTIONS = {
76126 "comment" ,
77127 "mail-user" ,
@@ -81,6 +131,7 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
81131 "partition" ,
82132 "time" ,
83133 "constraint" ,
134+ "qos" ,
84135}
85136
86137log : logging .Logger = logging .getLogger (__name__ )
@@ -106,6 +157,7 @@ def _apply_app_id_env(s: str) -> str:
106157 "mail-user" : Optional [str ],
107158 "mail-type" : Optional [str ],
108159 "job_dir" : Optional [str ],
160+ "qos" : Optional [str ],
109161 },
110162 total = False ,
111163)
@@ -126,7 +178,11 @@ class SlurmReplicaRequest:
126178
127179 @classmethod
128180 def from_role (
129- cls , name : str , role : Role , cfg : SlurmOpts , nomem : bool
181+ cls ,
182+ name : str ,
183+ role : Role ,
184+ cfg : SlurmOpts ,
185+ nomem : bool ,
130186 ) -> "SlurmReplicaRequest" :
131187 """
132188 ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +205,11 @@ def from_role(
149205 if not nomem and resource .memMB > 0 :
150206 sbatch_opts .setdefault ("mem" , str (resource .memMB ))
151207 if resource .gpu > 0 :
152- sbatch_opts .setdefault ("gpus-per-task" , str (resource .gpu ))
208+ # Use smart GPU allocation based on automatically detected Slurm version
209+ if _should_use_gpus_per_node_from_version ():
210+ sbatch_opts .setdefault ("gpus-per-node" , str (resource .gpu ))
211+ else :
212+ sbatch_opts .setdefault ("gpus-per-task" , str (resource .gpu ))
153213
154214 srun_opts = {
155215 "output" : f"slurm-{ macros .app_id } -{ name } .out" ,
@@ -378,6 +438,11 @@ def _run_opts(self) -> runopts:
378438 iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
379439 """ ,
380440 )
441+ opts .add (
442+ "qos" ,
443+ type_ = str ,
444+ help = "Quality of Service (QoS) to assign to the job." ,
445+ )
381446 return opts
382447
383448 def schedule (self , dryrun_info : AppDryRunInfo [SlurmBatchRequest ]) -> str :
0 commit comments