diff --git a/torchx/schedulers/local_scheduler.py b/torchx/schedulers/local_scheduler.py index a7c3cbfa7..73fadd73e 100644 --- a/torchx/schedulers/local_scheduler.py +++ b/torchx/schedulers/local_scheduler.py @@ -27,7 +27,17 @@ from dataclasses import asdict, dataclass from datetime import datetime from types import FrameType -from typing import Any, BinaryIO, Callable, Dict, Iterable, List, Optional, TextIO +from typing import ( + Any, + BinaryIO, + Callable, + Dict, + Iterable, + List, + Optional, + TextIO, + Tuple, +) from torchx.schedulers.api import ( AppDryRunInfo, @@ -658,29 +668,11 @@ def _popen( as file name ``str`` rather than a file-like obj. """ - stdout_ = self._get_file_io(replica_params.stdout) - stderr_ = self._get_file_io(replica_params.stderr) - combined_: Optional[Tee] = None - combined_file = self._get_file_io(replica_params.combined) - if combined_file: - combined_ = Tee( - combined_file, - none_throws(replica_params.stdout), - none_throws(replica_params.stderr), - ) - - # inherit parent's env vars since 99.9% of the time we want this behavior - # just make sure we override the parent's env vars with the user_defined ones - env = os.environ.copy() - env.update(replica_params.env) - # PATH is a special one, instead of overriding, append - env["PATH"] = _join_PATH(replica_params.env.get("PATH"), os.getenv("PATH")) - - # default to unbuffered python for faster responsiveness locally - env.setdefault("PYTHONUNBUFFERED", "x") + stdout_, stderr_, combined_ = self._get_replica_output_handles(replica_params) args_pfmt = pprint.pformat(asdict(replica_params), indent=2, width=80) log.debug(f"Running {role_name} (replica {replica_id}):\n {args_pfmt}") + env = self._get_replica_env(replica_params) proc = subprocess.Popen( args=replica_params.args, @@ -700,6 +692,47 @@ def _popen( error_file=env.get("TORCHELASTIC_ERROR_FILE", ""), ) + def _get_replica_output_handles( + self, + replica_params: ReplicaParam, + ) -> Tuple[Optional[io.FileIO], Optional[io.FileIO], Optional[Tee]]: + """ + Returns the stdout, stderr, and combined outputs of the replica. + If the combined output file is not specified, then the combined output is ``None``. + """ + + stdout_ = self._get_file_io(replica_params.stdout) + stderr_ = self._get_file_io(replica_params.stderr) + combined_: Optional[Tee] = None + combined_file = self._get_file_io(replica_params.combined) + if combined_file: + combined_ = Tee( + combined_file, + none_throws(replica_params.stdout), + none_throws(replica_params.stderr), + ) + return stdout_, stderr_, combined_ + + def _get_replica_env( + self, + replica_params: ReplicaParam, + ) -> Dict[str, str]: + """ + Returns environment variables for the ``_LocalReplica`` + """ + + # inherit parent's env vars since 99.9% of the time we want this behavior + # just make sure we override the parent's env vars with the user_defined ones + env = os.environ.copy() + env.update(replica_params.env) + # PATH is a special one, instead of overriding, append + env["PATH"] = _join_PATH(replica_params.env.get("PATH"), os.getenv("PATH")) + + # default to unbuffered python for faster responsiveness locally + env.setdefault("PYTHONUNBUFFERED", "x") + + return env + def _get_app_log_dir(self, app_id: str, cfg: LocalOpts) -> str: """ Returns the log dir. We redirect stdout/err