Skip to content

Commit

Permalink
Extract env set up/output handling in local_scheduler for easier subc…
Browse files Browse the repository at this point in the history
…lassing (#817)

Summary:
## No functional change
Extracted stdout, stderr, and env var handling logic from `local_scheduler`'s `_popen` for easier subclass function overrides.

Differential Revision: D53453899

Co-authored-by: Cheng Ni <[email protected]>
  • Loading branch information
cniii and cniii committed Feb 7, 2024
1 parent 7aabf3d commit 4a2d403
Showing 1 changed file with 54 additions and 21 deletions.
75 changes: 54 additions & 21 deletions torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,17 @@
from dataclasses import asdict, dataclass
from datetime import datetime
from types import FrameType
from typing import Any, BinaryIO, Callable, Dict, Iterable, List, Optional, TextIO
from typing import (
Any,
BinaryIO,
Callable,
Dict,
Iterable,
List,
Optional,
TextIO,
Tuple,
)

from torchx.schedulers.api import (
AppDryRunInfo,
Expand Down Expand Up @@ -658,29 +668,11 @@ def _popen(
as file name ``str`` rather than a file-like obj.
"""

stdout_ = self._get_file_io(replica_params.stdout)
stderr_ = self._get_file_io(replica_params.stderr)
combined_: Optional[Tee] = None
combined_file = self._get_file_io(replica_params.combined)
if combined_file:
combined_ = Tee(
combined_file,
none_throws(replica_params.stdout),
none_throws(replica_params.stderr),
)

# inherit parent's env vars since 99.9% of the time we want this behavior
# just make sure we override the parent's env vars with the user_defined ones
env = os.environ.copy()
env.update(replica_params.env)
# PATH is a special one, instead of overriding, append
env["PATH"] = _join_PATH(replica_params.env.get("PATH"), os.getenv("PATH"))

# default to unbuffered python for faster responsiveness locally
env.setdefault("PYTHONUNBUFFERED", "x")
stdout_, stderr_, combined_ = self._get_replica_output_handles(replica_params)

args_pfmt = pprint.pformat(asdict(replica_params), indent=2, width=80)
log.debug(f"Running {role_name} (replica {replica_id}):\n {args_pfmt}")
env = self._get_replica_env(replica_params)

proc = subprocess.Popen(
args=replica_params.args,
Expand All @@ -700,6 +692,47 @@ def _popen(
error_file=env.get("TORCHELASTIC_ERROR_FILE", "<N/A>"),
)

def _get_replica_output_handles(
    self,
    replica_params: ReplicaParam,
) -> Tuple[Optional[io.FileIO], Optional[io.FileIO], Optional[Tee]]:
    """
    Resolves the output handles for a single replica.

    Kept as its own method so that subclasses can override how replica
    output is wired up without re-implementing ``_popen``.

    Args:
        replica_params: replica parameters; only the ``stdout``, ``stderr``
            and ``combined`` attributes are read here.

    Returns:
        A ``(stdout, stderr, combined)`` tuple. The first two are whatever
        ``self._get_file_io`` yields for ``replica_params.stdout`` and
        ``replica_params.stderr``. ``combined`` is a ``Tee`` over the
        combined file handle, or ``None`` when ``self._get_file_io``
        yields a falsy value for ``replica_params.combined``.
    """

    # resolve in the order stdout -> stderr -> combined
    out_io = self._get_file_io(replica_params.stdout)
    err_io = self._get_file_io(replica_params.stderr)
    combined_io = self._get_file_io(replica_params.combined)

    # no combined output requested: nothing to tee into
    if not combined_io:
        return out_io, err_io, None

    # NOTE: the Tee receives the raw ``stdout``/``stderr`` values from
    # ``replica_params`` (not the handles resolved above); both are required
    # to be non-None here -- presumably ``none_throws`` raises otherwise.
    tee = Tee(
        combined_io,
        none_throws(replica_params.stdout),
        none_throws(replica_params.stderr),
    )
    return out_io, err_io, tee

def _get_replica_env(
    self,
    replica_params: ReplicaParam,
) -> Dict[str, str]:
    """
    Builds the environment variables that the replica's process runs with.

    Kept as its own method so that subclasses can override how the replica
    environment is assembled without re-implementing ``_popen``.

    Args:
        replica_params: replica parameters; only the ``env`` mapping is
            read here.

    Returns:
        A new dict made of the parent process' environment overlaid with
        ``replica_params.env`` (user-defined values win), with ``PATH``
        recomputed via ``_join_PATH`` and ``PYTHONUNBUFFERED`` set to
        ``"x"`` unless it is already present.
    """

    user_env = replica_params.env

    # start from the parent's env vars (wanted 99.9% of the time) and let
    # the user-defined ones take precedence on conflicts
    merged = {**os.environ, **user_env}

    # PATH is special: rather than letting the user's value replace the
    # parent's, both are handed to _join_PATH (user-defined first)
    merged["PATH"] = _join_PATH(user_env.get("PATH"), os.getenv("PATH"))

    # unbuffered python by default for faster responsiveness locally;
    # an explicit value from the parent or the user is left untouched
    merged.setdefault("PYTHONUNBUFFERED", "x")

    return merged

def _get_app_log_dir(self, app_id: str, cfg: LocalOpts) -> str:
"""
Returns the log dir. We redirect stdout/err
Expand Down

0 comments on commit 4a2d403

Please sign in to comment.