diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 8b12ee7e..f3b1e13e 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -476,7 +476,8 @@ def main(args): import yaml metric_logger = AsyncStructuredLogger( - args.output_dir + "/training_params_and_metrics.jsonl" + args.output_dir + + f"/training_params_and_metrics_global{os.environ['RANK']}.jsonl" ) if os.environ["LOCAL_RANK"] == "0": print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m") @@ -658,7 +659,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: print(f"\033[92mRunning command: {' '.join(command)}\033[0m") process = None try: - process = StreamablePopen(command) + process = StreamablePopen( + f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log", + command, + ) except KeyboardInterrupt: print("Process interrupted by user") diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 5ac28152..ffdcd482 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -88,21 +88,23 @@ class StreamablePopen(subprocess.Popen): Provides a way of reading stdout and stderr line by line. """ - def __init__(self, *args, **kwargs): + def __init__(self, output_file, *args, **kwargs): # remove the stderr and stdout from kwargs kwargs.pop("stderr", None) kwargs.pop("stdout", None) - super().__init__(*args, **kwargs) - while True: - if self.stdout: - output = self.stdout.readline().strip() - print(output) - if self.stderr: - error = self.stderr.readline().strip() - print(error, file=sys.stderr) - if self.poll() is not None: - break + super().__init__( + *args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs + ) + with open(output_file, "wb") as full_log_file: + while True: + byte = self.stdout.read(1) + if byte: + sys.stdout.buffer.write(byte) + sys.stdout.flush() + full_log_file.write(byte) + else: + break def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000):