Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for parallel syncing #85

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/wandb_osh/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def _get_parser() -> ArgumentParser:
type=int,
help="Timeout for wandb sync. If <=0, no timeout.",
)
parser.add_argument(
"--num-workers",
default=1,
type=int,
help="Number of parallel syncs to run at a time.",
)
parser.add_argument(
"wandb_options",
nargs="*",
Expand All @@ -44,9 +50,12 @@ def main(argv=None) -> None:
parser = _get_parser()
args = parser.parse_args(argv)
wandb_osh = WandbSyncer(
command_dir=args.command_dir, wait=args.wait, wandb_options=args.wandb_options
command_dir=args.command_dir,
wait=args.wait,
wandb_options=args.wandb_options,
num_workers=args.num_workers,
)
wandb_osh.loop()
wandb_osh.start()


if __name__ == "__main__":
Expand Down
60 changes: 40 additions & 20 deletions src/wandb_osh/syncer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import os
import subprocess
import time
from multiprocessing import Process, Queue
from os import PathLike
from pathlib import Path
from queue import Empty

from wandb_osh import __version__
from wandb_osh.config import _command_dir_default
Expand All @@ -19,6 +21,7 @@ def __init__(
wandb_options: list[str] | None = None,
*,
timeout: int | float = 120,
num_workers: int = 1,
):
"""Class for interpreting command files and triggering
`wandb sync`.
Expand All @@ -35,6 +38,25 @@ def __init__(
self.wait = wait
self.wandb_options = wandb_options
self._timeout = timeout
self.num_workers = num_workers
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's mark these three attributes private and also the new methods

self.target_queue: Queue = Queue()
self.workers: list[Process] = []

def start(self) -> None:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though the new name makes more sense, this would break backward compatible, so probably best to keep this as "loop".

"""Start directory watcher process and sync workers
Args:
None
"""
watcher = Process(target=self.dir_watcher)
watcher.start()

self.command_dir.mkdir(parents=True, exist_ok=True)

for _ in range(self.num_workers):
p = Process(target=self.worker)
self.workers.append(p)
p.start()

def sync(self, dir: PathLike) -> None:
"""Sync a directory. Thin wrapper around the `sync_dir` function.
Expand All @@ -44,44 +66,42 @@ def sync(self, dir: PathLike) -> None:
"""
sync_dir(dir, options=self.wandb_options, timeout=self._timeout)

def loop(self) -> None:
def dir_watcher(self) -> None:
"""Read command files and trigger syncing"""
logger.info(
"wandb-osh v%s, starting to watch %s", __version__, self.command_dir
)
while True:
start_time = time.time()
self.command_dir.mkdir(parents=True, exist_ok=True)
command_files = []
targets = []
for command_file in self.command_dir.glob("*.command"):
target = Path(command_file.read_text())
command_files.append(command_file)
if not target.is_dir():
logger.error(
"Command file %s points to non-existing directory %s",
command_file,
target,
)
continue
targets.append(target)
for target in set(targets):
logger.info("Syncing %s...", target)
try:
self.sync(target)
except subprocess.TimeoutExpired:
# try again later
logger.warning("Syncing %s timed out. Trying later.", target)
from wandb_osh.hooks import TriggerWandbSyncHook

TriggerWandbSyncHook(self.command_dir)(target)
time.sleep(0.25)
for cf in command_files:
self.target_queue.put((command_file, target))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the previous version, we were making sure that we don't have the same synchronization target more than once in order to avoid one process outcrowding the others (as you reported in #83 , that wasn't enough though). I think we should still do this, i.e., we check if the target is already in the queue (if it is, we simply remove the command_file right here and don't add anything to the queue).

time.sleep(max(0.0, (time.time() - start_time) - self.wait))

def worker(self) -> None:
while True:
try:
cf, target = self.target_queue.get(timeout=self._timeout)
self.sync(target)
time.sleep(0.25)
if cf.is_file():
cf.unlink()
if "PYTEST_CURRENT_TEST" in os.environ:
break
time.sleep(max(0.0, (time.time() - start_time) - self.wait))
if "PYTEST_CURRENT_TEST" in os.environ:
break
except Empty:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be for the case that we haven't encountered any new synchronization requests. So this should be after target_queue.get and simply be ignored (with a small timeout), right? Because it's nothing bad.

The "try again later" code below is if self.sync(target) times out

# try again later
logger.warning("Syncing %s timed out. Trying later.", target)
from wandb_osh.hooks import TriggerWandbSyncHook

TriggerWandbSyncHook(self.command_dir)(target)


def sync_dir(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_syncer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ def test_wandb_syncer(tmp_path, caplog):
target = tmp_path / "test" / "123"
(tmp_path / "123.command").write_text(str(target.resolve()))
with caplog.at_level(logging.WARNING):
ws.loop()
ws.start()
assert "points to non-existing directory" in caplog.text
caplog.clear()
(tmp_path / "123.command").write_text(str(target.resolve()))
target.mkdir(parents=True)
with caplog.at_level(logging.DEBUG):
ws.loop()
ws.start()
assert f"Command would be: wandb sync . in {target.resolve()}" in caplog.text
set_log_level()

Expand All @@ -38,5 +38,5 @@ def test_wandb_sync_timeout(tmp_path, caplog):
(tmp_path / "123.command").write_text(str(target.resolve()))
target.mkdir(parents=True)
with caplog.at_level(logging.DEBUG):
ws.loop()
ws.start()
assert "timed out. Trying later." in caplog.text