Skip to content

Commit

Permalink
Move logic from TorchX CLI -> API, so MVAI can call it (#955)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #955

MVAI's "light" is synchronous - you can immediately see the logs for jobs you start. Only "fire" is asynchronous.

TorchX's API, since it's generic, *always* creates jobs that are asynchronous. Therefore, there isn't a built-in interface for "tailing" the stderr of every started process - just for tailing individual replicas of a given role.

The TorchX CLI's `torchx run` command **has** implemented this, but its implementation is coupled with the CLI implementations of `torchx run` and `torchx log`.

This diff extracts the useful logic into a helper function of the TorchX API

Reviewed By: andywag

Differential Revision: D62463211
  • Loading branch information
Julie Ganeshan authored and facebook-github-bot committed Sep 11, 2024
1 parent ce17fbb commit 0bfee30
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 35 deletions.
28 changes: 4 additions & 24 deletions torchx/cli/cmd_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
from torchx.schedulers.api import Stream
from torchx.specs.api import is_started
from torchx.specs.builders import make_app_handle
from torchx.util.log_tee_helpers import (
_find_role_replicas as find_role_replicas,
_prefix_line,
)

from torchx.util.types import none_throws

Expand All @@ -39,19 +43,6 @@ def validate(job_identifier: str) -> None:
sys.exit(1)


def _prefix_line(prefix: str, line: str) -> str:
"""
_prefix_line ensure the prefix is still present even when dealing with return characters
"""
if "\r" in line:
line = line.replace("\r", f"\r{prefix}")
if "\n" in line[:-1]:
line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
if not line.startswith("\r"):
line = f"{prefix}{line}"
return line


def print_log_lines(
file: TextIO,
runner: Runner,
Expand Down Expand Up @@ -167,17 +158,6 @@ def get_logs(
raise threads_exceptions[0]


def find_role_replicas(
app: specs.AppDef, role_name: Optional[str]
) -> List[Tuple[str, int]]:
role_replicas = []
for role in app.roles:
if role_name is None or role_name == role.name:
for i in range(role.num_replicas):
role_replicas.append((role.name, i))
return role_replicas


class CmdLog(SubCommand):
def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
subparser.add_argument(
Expand Down
20 changes: 9 additions & 11 deletions torchx/cli/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import torchx.specs as specs
from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run
from torchx.cli.cmd_base import SubCommand
from torchx.cli.cmd_log import get_logs
from torchx.runner import config, get_runner, Runner
from torchx.runner.config import load_sections
from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
Expand All @@ -32,6 +31,7 @@
get_builtin_source,
get_components,
)
from torchx.util.log_tee_helpers import tee_logs
from torchx.util.types import none_throws


Expand Down Expand Up @@ -288,16 +288,14 @@ def _wait_and_exit(self, runner: Runner, app_handle: str, log: bool) -> None:
logger.debug(status)

def _start_log_thread(self, runner: Runner, app_handle: str) -> threading.Thread:
thread = threading.Thread(
target=get_logs,
kwargs={
"file": sys.stderr,
"runner": runner,
"identifier": app_handle,
"regex": None,
"should_tail": True,
},
thread = tee_logs(
dst=sys.stderr,
app_handle=app_handle,
regex=None,
runner=runner,
should_tail=True,
streams=None,
colorize=not sys.stderr.closed and sys.stderr.isatty(),
)
thread.daemon = True
thread.start()
return thread
210 changes: 210 additions & 0 deletions torchx/util/log_tee_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
If you're wrapping the TorchX API with your own CLI, these functions can
help show the logs of the job within your CLI, just like
`torchx log`
"""

import logging
import threading
from queue import Queue
from typing import List, Optional, TextIO, Tuple, TYPE_CHECKING

from torchx.util.types import none_throws

if TYPE_CHECKING:
from torchx.runner.api import Runner
from torchx.schedulers.api import Stream
from torchx.specs.api import AppDef

logger: logging.Logger = logging.getLogger(__name__)

# A torchX job can have stderr/stdout for many replicas, of many roles
# The scheduler API has functions that allow us to get,
# with unspecified detail, the log lines of a given replica of
# a given role.
#
# So, to neatly tee the results, we:
# 1) Determine every role ID / replica ID pair we want to monitor
# 2) Request the given stderr / stdout / combined streams from them (1 thread each)
# 3) Concatenate each of those streams to a given destination file


def _find_role_replicas(
app: "AppDef",
role_name: Optional[str],
) -> List[Tuple[str, int]]:
"""
Enumerate all (role, replica id) pairs in the given AppDef.
Replica IDs are 0-indexed, and range up to num_replicas,
for each role.
If role_name is provided, filters to only that name.
"""
role_replicas = []
for role in app.roles:
if role_name is None or role_name == role.name:
for i in range(role.num_replicas):
role_replicas.append((role.name, i))
return role_replicas


def _prefix_line(prefix: str, line: str) -> str:
"""
_prefix_line ensure the prefix is still present even when dealing with return characters
"""
if "\r" in line:
line = line.replace("\r", f"\r{prefix}")
if "\n" in line[:-1]:
line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
if not line.startswith("\r"):
line = f"{prefix}{line}"
return line


def _print_log_lines_for_role_replica(
dst: TextIO,
app_handle: str,
regex: Optional[str],
runner: "Runner",
which_role: str,
which_replica: int,
exceptions: "Queue[Exception]",
should_tail: bool,
streams: Optional["Stream"],
colorize: bool = False,
) -> None:
"""
Helper function that'll run in parallel - one
per monitored replica of a given role.
Based on print_log_lines .. but not designed for TTY
"""
try:
for line in runner.log_lines(
app_handle,
which_role,
which_replica,
regex,
should_tail=should_tail,
streams=streams,
):
if colorize:
color_begin = "\033[32m"
color_end = "\033[0m"
else:
color_begin = ""
color_end = ""
prefix = f"{color_begin}{which_role}/{which_replica}{color_end} "
print(_prefix_line(prefix, line), file=dst, end="", flush=True)
except Exception as e:
exceptions.put(e)
raise


def _start_threads_to_monitor_role_replicas(
dst: TextIO,
app_handle: str,
regex: Optional[str],
runner: "Runner",
which_role: Optional[str] = None,
should_tail: bool = False,
streams: Optional["Stream"] = None,
colorize: bool = False,
) -> None:
threads = []

app = none_throws(runner.describe(app_handle))
replica_ids = _find_role_replicas(app, role_name=which_role)

# Holds exceptions raised by all threads, in a thread-safe
# object
exceptions = Queue()

if not replica_ids:
valid_roles = [role.name for role in app.roles]
raise ValueError(
f"{which_role} is not a valid role name. Available: {valid_roles}"
)

for role_name, replica_id in replica_ids:
threads.append(
threading.Thread(
target=_print_log_lines_for_role_replica,
kwargs={
"dst": dst,
"runner": runner,
"app_handle": app_handle,
"which_role": role_name,
"which_replica": replica_id,
"regex": regex,
"should_tail": should_tail,
"exceptions": exceptions,
"streams": streams,
"colorize": colorize,
},
daemon=True,
)
)

for t in threads:
t.start()

for t in threads:
t.join()

# Retrieve all exceptions, print all except one and raise the first recorded exception
threads_exceptions = []
while not exceptions.empty():
threads_exceptions.append(exceptions.get())

if len(threads_exceptions) > 0:
for i in range(1, len(threads_exceptions)):
logger.error(threads_exceptions[i])

raise threads_exceptions[0]


def tee_logs(
dst: TextIO,
app_handle: str,
regex: Optional[str],
runner: "Runner",
should_tail: bool = False,
streams: Optional["Stream"] = None,
colorize: bool = False,
) -> threading.Thread:
"""
Makes a thread, which in turn will start 1 thread per replica
per role, that tees that role-replica's logs to the given
destination buffer.
You'll need to start and join with this parent thread.
dst: TextIO to tee the logs into
app_handle: The return value of runner.run() or runner.schedule()
regex: Regex to filter the logs that are tee-d
runner: The Runner you used to schedule the job
should_tail: If true, continue until we run out of logs. Otherwise, just fetch
what's available
streams: Whether to fetch STDERR, STDOUT, or the temporally COMBINED (default) logs
"""
thread = threading.Thread(
target=_start_threads_to_monitor_role_replicas,
kwargs={
"dst": dst,
"runner": runner,
"app_handle": app_handle,
"regex": None,
"should_tail": True,
"colorize": colorize,
},
daemon=True,
)
return thread

0 comments on commit 0bfee30

Please sign in to comment.