From 51c8c3c1b5cedd5b62061637aa46588f30840470 Mon Sep 17 00:00:00 2001
From: George Muraru
Date: Wed, 9 Dec 2020 01:38:06 +0200
Subject: [PATCH 1/3] Add generic multi-machine launcher

---
 examples/mpc_linear_svm/launcher.py           |   4 +-
 scripts/aws_launcher.py                       | 396 ------------------
 scripts/multiple_machines/aws_launcher.py     | 199 +++++++++
 scripts/multiple_machines/common.py           | 288 +++++++++++++
 scripts/multiple_machines/generic_launcher.py |  80 ++++
 5 files changed, 569 insertions(+), 398 deletions(-)
 delete mode 100644 scripts/aws_launcher.py
 create mode 100644 scripts/multiple_machines/aws_launcher.py
 create mode 100644 scripts/multiple_machines/common.py
 create mode 100644 scripts/multiple_machines/generic_launcher.py

diff --git a/examples/mpc_linear_svm/launcher.py b/examples/mpc_linear_svm/launcher.py
index d88a1d33..56903d57 100644
--- a/examples/mpc_linear_svm/launcher.py
+++ b/examples/mpc_linear_svm/launcher.py
@@ -23,8 +23,6 @@
 import logging
 import os
 
-from examples.multiprocess_launcher import MultiProcessLauncher
-
 parser = argparse.ArgumentParser(description="CrypTen Linear SVM Training")
 
 parser.add_argument(
@@ -82,6 +80,8 @@ def _run_experiment(args):
 def main(run_experiment):
     args = parser.parse_args()
     if args.multiprocess:
+        from examples.multiprocess_launcher import MultiProcessLauncher
+
         launcher = MultiProcessLauncher(args.world_size, run_experiment, args)
         launcher.start()
         launcher.join()
diff --git a/scripts/aws_launcher.py b/scripts/aws_launcher.py
deleted file mode 100644
index 815375ce..00000000
--- a/scripts/aws_launcher.py
+++ /dev/null
@@ -1,396 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-"""
-This file is a tool to run MPC distributed training over AWS.
-
-To run distributed training, first multiple AWS instances needs to be created
-with a public AMI "Deep Learning AMI (Ubuntu) Version 24.0":
-
-$ aws ec2 run-instances \
-    --image-id ami-0ddba16a97b1dcda5 \
-    --count 2 \
-    --instance-type t2.micro \
-    --key-name fair-$USER \
-    --tag-specifications "ResourceType=instance,Tags=[{Key=fair-user,Value=$USER}]"
-
-Two EC2 instances will be created by the command line shown above. Assume
-the ids of the two instances created are i-068681e808235a851 and
-i-0d7ebacfe1e3f28eb. Next, pytorch and crypten must be properly installed
-on every instance.
-
-Then the following command lines can run the mpc_linear_svm example on the two
-EC2 instances created above:
-
-$ python3 crypten/scripts/aws_launcher.py \
-    --SSH_keys=/home/$USER/.aws/fair-$USER.pem \
-    --instances=i-038dd14b9383b9d79,i-08f057b9c03d4a916 \
-    --aux_files=crypten/examples/mpc_linear_svm/mpc_linear_svm.py \
-    crypten/examples/mpc_linear_svm/launcher.py \
-        --features 50 \
-        --examples 100 \
-        --epochs 50 \
-        --lr 0.5 \
-        --skip_plaintext
-
-
-If you want to train with AWS instances located at multiple regions, then you would need
-to provide ssh_key_file for each instance:
-
-$ python3 crypten/scripts/aws_launcher.py \
-    --regions=us-east-1,us-west-1 \
-    --SSH_keys=/home/$USER/.aws/east.pem,/home/$USER/.aws/west.pem \
-    --instances=i-038dd14b9383b9d79,i-08f057b9c03d4a916 \
-    --aux_files=crypten/examples/mpc_linear_svm/mpc_linear_svm.py \
-    crypten/examples/mpc_linear_svm/launcher.py \
-        --features 50 \
-        --examples 100 \
-        --epochs 50 \
-        --lr 0.5 \
-        --skip_plaintext
-
-"""
-
-import concurrent.futures
-import configparser
-import os
-import sys
-import time
-import uuid
-import warnings
-from argparse import REMAINDER, ArgumentParser
-from pathlib import Path
-
-import boto3
-import paramiko
-
-
-def get_instances(ec2, instance_ids):
-    instances = list(
-        ec2.instances.filter(Filters=[{"Name": "instance-id", "Values": instance_ids}])
-    )
-    return instances
-
-
-def connect_to_instance(instance, keypath, username, http_proxy=None):
-    print(f"Connecting to {instance.id}...")
-
-    ip_address = instance.public_ip_address
-    if http_proxy:
-        # paramiko.ProxyCommand does not do string substitution for %h %p,
-        # so 'nc --proxy-type http --proxy fwdproxy:8080 %h %p' would not work!
-        proxy = paramiko.ProxyCommand(
-            f"nc --proxy-type http --proxy {http_proxy} {ip_address} {22}"
-        )
-        proxy.settimeout(300)
-    client = paramiko.SSHClient()
-    client.load_system_host_keys()
-    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-    retries = 20
-    while retries > 0:
-        try:
-            client.connect(
-                ip_address,
-                username=username,
-                key_filename=keypath,
-                timeout=10,
-                sock=proxy if http_proxy else None,
-            )
-            print(f"Connected to {instance.id}")
-            break
-        except Exception as e:
-            print(f"Exception: {e} Retrying...")
-            retries -= 1
-            time.sleep(10)
-    return client
-
-
-def add_prefix_each_line(prefix, str):
-    lines = [f"{prefix}{line}" for line in str.split("\n")]
-    return "\n".join(lines)
-
-
-def run_command(instance, client, cmd, environment=None, inputs=None):
-    stdin, stdout, stderr = client.exec_command(
-        cmd, get_pty=True, environment=environment
-    )
-    if inputs:
-        for inp in inputs:
-            stdin.write(inp)
-
-    def read_lines(fin, fout, line_head):
-        line = ""
-        while not fin.channel.exit_status_ready():
-            line += fin.read(1).decode("utf8")
-            if line.endswith("\n"):
-                print(f"{line_head}{line[:-1]}", file=fout)
-                line = ""
-        if line:
-            # print what remains in line buffer, in case fout does not
-            # end with '\n'
-            print(f"{line_head}{line[:-1]}", file=fout)
-
-    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as printer:
-        printer.submit(read_lines, stdout, sys.stdout, f"[{instance} STDOUT] ")
-        printer.submit(read_lines, stderr, sys.stderr, f"[{instance} STDERR] ")
-
-
-def upload_file(instance_id, client, localpath, remotepath):
-    ftp_client = client.open_sftp()
-    print(f"Uploading `{localpath}` to {instance_id}...")
-    ftp_client.put(localpath, remotepath)
-    ftp_client.close()
-    print(f"`{localpath}` uploaded to {instance_id}.")
-
-
-def main():
-    args = parse_args()
-
-    cf = configparser.ConfigParser()
-    cf.read(args.credentials)
-
-    warnings.filterwarnings(
-        "ignore", category=ResourceWarning, message="unclosed.*"
-    )
-
-    regions = args.regions.split(",")
-    instance_ids = args.instances.split(",")
-    ssh_key_files = args.ssh_key_file.split(",")
-
-    instances = []
-    if len(regions) > 1:
-        print("Multiple regions detected")
-
-        assert len(instance_ids) == len(
-            ssh_key_files
-        ), "{} instance ids are provided, but {} SSH keys found.".format(
-            len(instance_ids), len(ssh_key_files)
-        )
-
-        assert len(instance_ids) == len(
-            regions
-        ), "{} instance ids are provided, but {} regions found.".format(
-            len(instance_ids), len(regions)
-        )
-
-        for i, region in enumerate(regions):
-            session = boto3.session.Session(
-                aws_access_key_id=cf["default"]["aws_access_key_id"],
-                aws_secret_access_key=cf["default"]["aws_secret_access_key"],
-                region_name=region,
-            )
-            ec2 = session.resource("ec2")
-
-            instance = get_instances(ec2, [instance_ids[i]])
-            instances += instance
-    else:
-        session = boto3.session.Session(
-            aws_access_key_id=cf["default"]["aws_access_key_id"],
-            aws_secret_access_key=cf["default"]["aws_secret_access_key"],
-            region_name=regions[0],
-        )
-        ec2 = session.resource("ec2")
-        instances = get_instances(ec2, instance_ids)
-
-        assert (
-            len(ssh_key_files) == 1
-        ), "1 region is detected, but {} SSH keys found.".format(len(ssh_key_files))
-
-        ssh_key_files = [ssh_key_files[0] for _ in range(len(instances))]
-
-    assert len(instance_ids) == len(
-        instances
-    ), "{} instance ids are provided, but {} found.".format(
-        len(instance_ids), len(instances)
-    )
-
-    # Only print the public IP addresses of the instances.
-    # Then do nothing else and return.
-    if args.only_show_instance_ips:
-        for instance in instances:
-            print(instance.public_ip_address)
-        return
-
-    world_size = len(instances)
-    print(f"Running world size {world_size} with instances: {instances}")
-    master_instance = instances[0]
-
-    # Key: instance id; value: paramiko.SSHClient object.
-    client_dict = {}
-    for i, instance in enumerate(instances):
-        client = connect_to_instance(
-            instance, ssh_key_files[i], args.ssh_user, args.http_proxy
-        )
-        client_dict[instance.id] = client
-
-    assert os.path.exists(
-        args.training_script
-    ), f"File `{args.training_script}` does not exist"
-    file_paths = args.aux_files.split(",") if args.aux_files else []
-    for local_path in file_paths:
-        assert os.path.exists(local_path), f"File `{local_path}` does not exist"
-
-    remote_dir = f"aws-launcher-tmp-{uuid.uuid1()}"
-    script_basename = os.path.basename(args.training_script)
-    remote_script = os.path.join(remote_dir, script_basename)
-
-    # Upload files to all instances concurrently.
-    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as uploaders:
-        for instance_id, client in client_dict.items():
-            run_command(instance_id, client, f"mkdir -p {remote_dir}")
-            uploaders.submit(
-                upload_file, instance_id, client, args.training_script, remote_script
-            )
-            for local_path in file_paths:
-                uploaders.submit(
-                    upload_file,
-                    instance_id,
-                    client,
-                    local_path,
-                    os.path.join(remote_dir, os.path.basename(local_path)),
-                )
-    for instance_id, client in client_dict.items():
-        run_command(instance_id, client, f"chmod +x {remote_script}")
-        run_command(instance_id, client, f"ls -al {remote_dir}")
-
-    environment = {
-        "WORLD_SIZE": str(world_size),
-        "RENDEZVOUS": "env://",
-        "MASTER_ADDR": master_instance.private_ip_address,
-        "MASTER_PORT": str(args.master_port),
-    }
-
-    with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as executor:
-        rank = 0
-        for instance_id, client in client_dict.items():
-            environment["RANK"] = str(rank)
-            # TODO: Although paramiko.SSHClient.exec_command() can accept
-            # an argument `environment`, it seems not to take effect in
-            # practice. It might because "Servers may silently reject
-            # some environment variables" according to paramiko document.
-            # As a workaround, here all environment variables are explicitly
-            # exported.
-            environment_cmd = "; ".join(
-                [f"export {key}={value}" for (key, value) in environment.items()]
-            )
-            prepare_cmd = f"{args.prepare_cmd}; " if args.prepare_cmd else ""
-            cmd = "{}; {} {} {} {}".format(
-                environment_cmd,
-                f"cd {remote_dir} ;",
-                prepare_cmd,
-                f"./{script_basename}",
-                " ".join(args.training_script_args),
-            )
-            print(f"Run command: {cmd}")
-            executor.submit(run_command, instance_id, client, cmd, environment)
-            rank += 1
-
-    # Cleanup temp dir.
-    for instance_id, client in client_dict.items():
-        run_command(instance_id, client, f"rm -rf {remote_dir}")
-        client.close()
-
-
-def parse_args():
-    """
-    Helper function parsing the command line options
-    """
-    parser = ArgumentParser(
-        description="PyTorch distributed training launch "
-        "helper utilty that will spawn up "
-        "parties for MPC scripts on AWS"
-    )
-
-    parser.add_argument(
-        "--credentials",
-        type=str,
-        default=f"{Path.home()}/.aws/credentials",
-        help="Credentials used to access AWS",
-    )
-
-    parser.add_argument(
-        "--only_show_instance_ips",
-        action="store_true",
-        default=False,
-        help="Only show public IPs of the given instances."
-        "No other actions will be done",
-    )
-
-    parser.add_argument("--regions", type=str, default="us-west-2", help="AWS Region")
-
-    parser.add_argument(
-        "--instances",
-        type=str,
-        required=True,
-        help="The comma-separated ids of AWS instances",
-    )
-
-    parser.add_argument(
-        "--master_port",
-        type=int,
-        default=29500,
-        help="The port used by master instance "
-        "for distributed training",
-    )
-
-    parser.add_argument(
-        "--ssh_key_file",
-        type=str,
-        required=True,
-        help="Path to the RSA private key file "
-        "used for instance authentication",
-    )
-
-    parser.add_argument(
-        "--ssh_user",
-        type=str,
-        default="ubuntu",
-        help="The username to ssh to AWS instance",
-    )
-
-    parser.add_argument(
-        "--http_proxy",
-        type=str,
-        default=None,
-        help="If not none, use the http proxy specified "
-        "(e.g., fwdproxy:8080) to ssh to AWS instance",
-    )
-
-    parser.add_argument(
-        "--aux_files",
-        type=str,
-        default=None,
-        help="The comma-separated paths of additional files "
-        " that need to be transferred to AWS instances. "
" - "If more than one file needs to be transferred, " - "the basename of any two files can not be the " - "same.", - ) - - parser.add_argument( - "--prepare_cmd", - type=str, - default="", - help="The command to run before running distribute " - "training for prepare purpose, e.g., setup " - "environment, extract data files, etc.", - ) - - # positional - parser.add_argument( - "training_script", - type=str, - help="The full path to the single machine training " - "program/script to be launched in parallel, " - "followed by all the arguments for the " - "training script", - ) - - # rest from the training program - parser.add_argument("training_script_args", nargs=REMAINDER) - return parser.parse_args() - - -if __name__ == "__main__": - main() diff --git a/scripts/multiple_machines/aws_launcher.py b/scripts/multiple_machines/aws_launcher.py new file mode 100644 index 00000000..6536d27b --- /dev/null +++ b/scripts/multiple_machines/aws_launcher.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +This file is a tool to run MPC distributed training over AWS. + +To run distributed training, first multiple AWS instances needs to be created +with a public AMI "Deep Learning AMI (Ubuntu) Version 24.0": + +$ aws ec2 run-instances \ + --image-id ami-0ddba16a97b1dcda5 \ + --count 2 \ + --instance-type t2.micro \ + --key-name fair-$USER \ + --tag-specifications "ResourceType=instance,Tags=[{Key=fair-user,Value=$USER}]" + +Two EC2 instances will be created by the command line shown above. Assume +the ids of the two instances created are i-068681e808235a851 and +i-0d7ebacfe1e3f28eb. Next, pytorch and crypten must be properly installed +on every instance. 
+
+Then the following command lines can run the mpc_linear_svm example on the two
+EC2 instances created above:
+
+$ python3 crypten/scripts/aws_launcher.py \
+    --SSH_keys=/home/$USER/.aws/fair-$USER.pem \
+    --instances=i-038dd14b9383b9d79,i-08f057b9c03d4a916 \
+    --aux_files=crypten/examples/mpc_linear_svm/mpc_linear_svm.py \
+    crypten/examples/mpc_linear_svm/launcher.py \
+        --features 50 \
+        --examples 100 \
+        --epochs 50 \
+        --lr 0.5 \
+        --skip_plaintext
+
+
+If you want to train with AWS instances located in multiple regions, then you need
+to provide an SSH key file for each instance:
+
+$ python3 crypten/scripts/aws_launcher.py \
+    --regions=us-east-1,us-west-1 \
+    --SSH_keys=/home/$USER/.aws/east.pem,/home/$USER/.aws/west.pem \
+    --instances=i-038dd14b9383b9d79,i-08f057b9c03d4a916 \
+    --aux_files=crypten/examples/mpc_linear_svm/mpc_linear_svm.py \
+    crypten/examples/mpc_linear_svm/launcher.py \
+        --features 50 \
+        --examples 100 \
+        --epochs 50 \
+        --lr 0.5 \
+        --skip_plaintext
+
+"""
+
+import configparser
+import os
+import sys
+import time
+import warnings
+from argparse import REMAINDER, ArgumentParser
+from pathlib import Path
+
+import boto3
+import paramiko
+
+import common
+
+def get_instances(ec2, instance_ids):
+    instances = list(
+        ec2.instances.filter(Filters=[{"Name": "instance-id", "Values": instance_ids}])
+    )
+    return instances
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    cf = configparser.ConfigParser()
+    cf.read(args.credentials)
+
+    warnings.filterwarnings(
+        "ignore", category=ResourceWarning, message="unclosed.*"
+    )
+
+    regions = args.regions.split(",")
+    instance_ids = args.instances.split(",")
+    ssh_key_files = args.ssh_key_file.split(",")
+
+    instances = []
+
+    # AWS Specific
+    if len(regions) > 1:
+        print("Multiple regions detected")
+
+        assert len(instance_ids) == len(
+            ssh_key_files
+        ), "{} instance ids are provided, but {} SSH keys found.".format(
+            len(instance_ids), len(ssh_key_files)
+        )
+
+        assert len(instance_ids) == len(
+            regions
+        ), "{} instance ids are provided, but {} regions found.".format(
+            len(instance_ids), len(regions)
+        )
+
+        for i, region in enumerate(regions):
+            session = boto3.session.Session(
+                aws_access_key_id=cf["default"]["aws_access_key_id"],
+                aws_secret_access_key=cf["default"]["aws_secret_access_key"],
+                region_name=region,
+            )
+            ec2 = session.resource("ec2")
+
+            instance = get_instances(ec2, [instance_ids[i]])
+            instances += instance
+    else:
+        session = boto3.session.Session(
+            aws_access_key_id=cf["default"]["aws_access_key_id"],
+            aws_secret_access_key=cf["default"]["aws_secret_access_key"],
+            region_name=regions[0],
+        )
+        ec2 = session.resource("ec2")
+        instances = get_instances(ec2, instance_ids)
+
+        assert (
+            len(ssh_key_files) == 1
+        ), "1 region is detected, but {} SSH keys found.".format(len(ssh_key_files))
+
+        ssh_key_files = [ssh_key_files[0] for _ in range(len(instances))]
+
+    assert len(instance_ids) == len(
+        instances
+    ), "{} instance ids are provided, but {} found.".format(
+        len(instance_ids), len(instances)
+    )
+
+    # Only print the public IP addresses of the instances.
+    # Then do nothing else and return.
+    if args.only_show_instance_ips:
+        for instance in instances:
+            print(instance.public_ip_address)
+        return
+
+    world_size = len(instances)
+    print(f"Running world size {world_size} with instances: {instances}")
+    master_instance = instances[0]
+
+    # Key: public IP address; value: paramiko.SSHClient object.
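+    # One paramiko SSH session per instance; all later commands and file
+    # uploads are issued through these sessions.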
+    client_dict = {}
+    for ssh_key_file, instance in zip(ssh_key_files, instances):
+        client = common.connect_to_machine(
+            instance.public_ip_address, ssh_key_file, args.ssh_user, args.http_proxy
+        )
+        client_dict[instance.public_ip_address] = client
+
+    remote_dir, script_basename = common.upload_files_to_machines(client_dict, ars.file_paths, args.script)
+
+    environment = {
+        "WORLD_SIZE": str(world_size),
+        "RENDEZVOUS": "env://",
+        "MASTER_ADDR": master_instance.private_ip_address,
+        "MASTER_PORT": str(args.master_port),
+    }
+
+    common.run_script_parallel(environment)
+
+    common.cleanup(client_dict, remote_dir)
+
+
+def get_parser():
+    parser = common.get_parser()
+
+    # Add AWS specific arguments
+    parser.add_argument(
+        "--only_show_instance_ips",
+        action="store_true",
+        default=False,
+        help="Only show public IPs of the given instances."
+        "No other actions will be done",
+    )
+
+    parser.add_argument("--regions", type=str, default="us-west-2", help="AWS Region")
+
+    parser.add_argument(
+        "--instances",
+        type=str,
+        required=True,
+        help="The comma-separated ids of AWS instances",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/multiple_machines/common.py b/scripts/multiple_machines/common.py
new file mode 100644
index 00000000..31bd0408
--- /dev/null
+++ b/scripts/multiple_machines/common.py
@@ -0,0 +1,288 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+This file contains helper functions to run MPC over multiple machines
+
+"""
+
+import concurrent.futures
+import configparser
+import os
+import sys
+import time
+import uuid
+import warnings
+from argparse import REMAINDER, ArgumentParser
+from pathlib import Path
+
+import boto3
+import paramiko
+
+
+def run_command(machine_ip, client, cmd, environment=None, inputs=None):
+    print(machine_ip, client)
+    stdin, stdout, stderr = client.exec_command(
+        cmd, get_pty=True, environment=environment
+    )
+    if inputs:
+        for inp in inputs:
+            stdin.write(inp)
+
+    def read_lines(fin, fout, line_head):
+        line = ""
+        while not fin.channel.exit_status_ready():
+            line += fin.read(1).decode("utf8")
+            if line.endswith("\n"):
+                print(f"{line_head}{line[:-1]}", file=fout)
+                line = ""
+        if line:
+            # print what remains in line buffer, in case fout does not
+            # end with '\n'
+            print(f"{line_head}{line[:-1]}", file=fout)
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as printer:
+        printer.submit(read_lines, stdout, sys.stdout, f"[{machine_ip} STDOUT] ")
+        printer.submit(read_lines, stderr, sys.stderr, f"[{machine_ip} STDERR] ")
+
+
+def upload_files_to_machines(client_dict, aux_files, script):
+    path_script = Path(script)
+
+    assert path_script.exists(), f"File `{script}` does not exist"
+    file_paths = aux_files.split(",") if aux_files else []
+
+    for local_path in file_paths:
+        assert Path(local_path).exists(), f"File `{local_path}` does not exist"
+
+    remote_dir = Path(f"tmp-dir-{uuid.uuid1()}")
+    script_basename = path_script.name
+    remote_script = remote_dir / script_basename
+
+    print(f"Remote path {remote_dir}")
+
+    def upload_file(machine_ip, client, localpath, remotepath):
+        ftp_client = client.open_sftp()
+        print(f"Uploading `{localpath}` to {machine_ip} as {remotepath}...")
+        ftp_client.put(localpath, str(remotepath), confirm=True)
+        ftp_client.close()
+        print(f"`{localpath}` uploaded to {machine_ip}")
+
+    # Upload files to all machines concurrently.
+    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as uploaders:
+        for machine_ip, client in client_dict.items():
+            run_command(machine_ip, client, f"mkdir {remote_dir}")
+            uploaders.submit(
+                upload_file, machine_ip, client, script, remote_script
+            )
+            for local_path in file_paths:
+                local_path = Path(local_path)
+                uploaders.submit(
+                    upload_file,
+                    machine_ip,
+                    client,
+                    local_path,
+                    remote_dir / local_path.name,
+                )
+
+    for machine_ip, client in client_dict.items():
+        run_command(machine_ip, client, f"chmod +x {remote_script}")
+        run_command(machine_ip, client, f"ls -al {remote_dir}")
+
+    return remote_dir, script_basename
+
+def run_script_parallel(
+    environment,
+    client_dict,
+    remote_dir,
+    script_basename,
+    script_args=None,
+    prepare_cmd=None
+    ):
+
+    world_size = len(client_dict)
+    with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as executor:
+        for rank, (machine_ip, client) in enumerate(client_dict.items()):
+            environment["RANK"] = str(rank)
+            # TODO: Although paramiko.SSHClient.exec_command() can accept
+            # an argument `environment`, it seems not to take effect in
+            # practice. It might because "Servers may silently reject
+            # some environment variables" according to paramiko document.
+            # As a workaround, here all environment variables are explicitly
+            # exported.
+            environment_cmd = "; ".join(
+                [f"export {key}={value}" for (key, value) in environment.items()]
+            )
+            prep_cmd = f"{prepare_cmd}; " if prepare_cmd else ""
+            if rank == 0:
+                environment_cmd += ";export GLOO_SOCKET_IFNAME=wlp59s0"
+            else:
+                environment_cmd += ";export GLOO_SOCKET_IFNAME=wlp2s0"
+            cmd = "{}; {} {} {} {}".format(
+                environment_cmd,
+                f"cd {remote_dir} ;",
+                prep_cmd,
+                f"./{script_basename}",
+                " ".join(script_args),
+            )
+            print(f"Run command: {cmd}")
+            executor.submit(run_command, machine_ip, client, cmd, environment)
+
+
+def connect_to_machine(ip_address, keypath, username, http_proxy=None, retries=20):
+    print(f"Connecting to {ip_address}...")
+
+    proxy = None
+    if http_proxy:
+        # paramiko.ProxyCommand does not do string substitution for %h %p,
+        # so 'nc --proxy-type http --proxy fwdproxy:8080 %h %p' would not work!
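+        # Instead, substitute the target host and port into the nc command
+        # line ourselves before handing it to ProxyCommand.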
+        proxy = paramiko.ProxyCommand(
+            f"nc --proxy-type http --proxy {http_proxy} {ip_address} {22}"
+        )
+        proxy.settimeout(300)
+
+    client = paramiko.SSHClient()
+    client.load_system_host_keys()
+    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+    while retries > 0:
+        try:
+            client.connect(
+                ip_address,
+                username=username,
+                key_filename=keypath,
+                timeout=10,
+                sock=proxy,
+            )
+            print(f"Connected to {ip_address}")
+            break
+        except Exception as e:
+            print(f"Exception: {e} Retrying...")
+            retries -= 1
+            time.sleep(10)
+    return client
+
+
+def cleanup(client_dict, remote_dir):
+    for machine_ip, client in client_dict.items():
+        run_command(machine_ip, client, f"rm -rf {remote_dir}")
+        client.close()
+
+
+def run_command(machine_ip, client, cmd, environment=None, inputs=None):
+    stdin, stdout, stderr = client.exec_command(
+        cmd, get_pty=True, environment=environment
+    )
+    if inputs:
+        for inp in inputs:
+            stdin.write(inp)
+
+    def read_lines(fin, fout, line_head):
+        line = ""
+        while not fin.channel.exit_status_ready():
+            line += fin.read(1).decode("utf8")
+            if line.endswith("\n"):
+                print(f"{line_head}{line[:-1]}", file=fout)
+                line = ""
+        if line:
+            # print what remains in line buffer, in case fout does not
+            # end with '\n'
+            print(f"{line_head}{line[:-1]}", file=fout)
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as printer:
+        printer.submit(read_lines, stdout, sys.stdout, f"[{machine_ip} STDOUT] ")
+        printer.submit(read_lines, stderr, sys.stderr, f"[{machine_ip} STDERR] ")
+
+
+def get_parser():
+    """
+    Helper function parsing the command line options
+    """
+    parser = ArgumentParser(
+        description="PyTorch distributed launcher "
+        "helper utility that will spawn up "
+        "parties for MPC scripts on different machines"
+    )
+
+    parser.add_argument(
+        "--credentials",
+        type=str,
+        default=f"{Path.home()}/.aws/credentials",
+        help="Credentials used to access the machines",
+    )
+
+    parser.add_argument(
+        "--ip_addresses",
+        type=str,
+        required=True,
+        help="The comma-separated ip addresses for the machines",
+    )
+
+    parser.add_argument(
+        "--master_port",
+        type=int,
+        default=29500,
+        help="The port used by master instance for MPC",
+    )
+
+    parser.add_argument(
+        "--ssh_key_file",
+        type=str,
+        required=True,
+        help="Path to the RSA private key file used for authentication",
+    )
+
+    parser.add_argument(
+        "--ssh_user",
+        type=str,
+        default="ubuntu",
+        help="The username to ssh to the other machines"
+    )
+
+    parser.add_argument(
+        "--http_proxy",
+        type=str,
+        default=None,
+        help="If not none, use the http proxy specified "
+        "(e.g., fwdproxy:8080) to ssh to machines",
+    )
+
+    parser.add_argument(
+        "--aux_files",
+        type=str,
+        default=None,
+        help="The comma-separated paths of additional files "
+        "that need to be transferred to the machines. "
" + "If more than one file needs to be transferred, " + "the basename of any two files can not be the " + "same.", + ) + + parser.add_argument( + "--prepare_cmd", + type=str, + default="", + help="The command to run before running distribute " + "training for prepare purpose, e.g., setup " + "environment, extract data files, etc.", + ) + + # positional + parser.add_argument( + "script", + type=str, + help="The full path to the single " + "program/script to be launched in parallel, " + "followed by all the arguments for the script", + ) + + # the rest of the arguments are passed to the script + parser.add_argument("script_args", nargs=REMAINDER) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/multiple_machines/generic_launcher.py b/scripts/multiple_machines/generic_launcher.py new file mode 100644 index 00000000..3d413e3b --- /dev/null +++ b/scripts/multiple_machines/generic_launcher.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +This file is a tool to run MPC over multiple machines (given a +set of IP addresses) + +python3 crypten/scripts/aws_launcher.py \ + --SSH_keys=/home/$USER/ \ + --instances=192.168.0.1,192.168.0.2 \ + --aux_files=crypten/examples/mpc_linear_svm/mpc_linear_svm.py \ + crypten/examples/mpc_linear_svm/launcher.py \ + --features 50 \ + --examples 100 \ + --epochs 50 \ + --lr 0.5 \ + --skip_plaintext +""" + +import concurrent.futures +import common +import configparser +import warnings + +def main(): + parser = common.get_parser() + args = parser.parse_args() + + cf = configparser.ConfigParser() + cf.read(args.credentials) + + warnings.filterwarnings( + "ignore", category=ResourceWarning, message="unclosed.*" + ) + + ip_addresses = args.ip_addresses.split(",") + ssh_key_files = args.ssh_key_file.split(",") + + if len(ssh_key_files) == 1: + ssh_key_files = [ssh_key_files[0] for _ in range(len(ip_addresses))] + + world_size = len(ip_addresses) + print(f"Running world size {world_size} with ip_addresses: {ip_addresses}") + master_ip_address = ip_addresses[0] + + # Key: instance id; value: paramiko.SSHClient object. 
+    client_dict = {}
+    for ssh_key_file, ip_address in zip(ssh_key_files, ip_addresses):
+        client = common.connect_to_machine(
+            ip_address, ssh_key_file, args.ssh_user, args.http_proxy
+        )
+        client_dict[ip_address] = client
+
+    remote_dir, script_basename = common.upload_files_to_machines(client_dict, args.aux_files, args.script)
+
+    environment = {
+        "WORLD_SIZE": str(world_size),
+        "RENDEZVOUS": "env://",
+        "MASTER_ADDR": master_ip_address,
+        "MASTER_PORT": str(args.master_port),
+    }
+
+    kwargs = {
+        "environment": environment,
+        "client_dict": client_dict,
+        "remote_dir": remote_dir,
+        "script_basename": script_basename,
+        "script_args": args.script_args,
+        "prepare_cmd": args.prepare_cmd
+    }
+    common.run_script_parallel(**kwargs)
+
+    common.cleanup(client_dict, remote_dir)
+
+
+if __name__ == "__main__":
+    main()

From f74bb7b0446c885ad5837d8dad9086d8ae1e7a28 Mon Sep 17 00:00:00 2001
From: George Muraru
Date: Wed, 9 Dec 2020 01:43:35 +0200
Subject: [PATCH 2/3] Remove debug print

---
 scripts/multiple_machines/common.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/scripts/multiple_machines/common.py b/scripts/multiple_machines/common.py
index 31bd0408..c2cd27a4 100644
--- a/scripts/multiple_machines/common.py
+++ b/scripts/multiple_machines/common.py
@@ -24,7 +24,6 @@
 
 
 def run_command(machine_ip, client, cmd, environment=None, inputs=None):
-    print(machine_ip, client)
     stdin, stdout, stderr = client.exec_command(
         cmd, get_pty=True, environment=environment
     )

From ea6d9b5186221c58f8800231665b55c81c2558b4 Mon Sep 17 00:00:00 2001
From: George Muraru
Date: Wed, 9 Dec 2020 18:14:34 +0200
Subject: [PATCH 3/3] Fix AWS launcher

---
 scripts/multiple_machines/aws_launcher.py     | 29 +++++++----
 scripts/multiple_machines/common.py           | 51 +------------------
 scripts/multiple_machines/generic_launcher.py | 19 ++++---
 3 files changed, 35 insertions(+), 64 deletions(-)

diff --git a/scripts/multiple_machines/aws_launcher.py b/scripts/multiple_machines/aws_launcher.py
index 6536d27b..1d6edfc3 100644
--- a/scripts/multiple_machines/aws_launcher.py
+++ b/scripts/multiple_machines/aws_launcher.py
@@ -26,7 +26,7 @@ EC2 instances created above:
 
 $ python3 crypten/scripts/aws_launcher.py \
-    --SSH_keys=/home/$USER/.aws/fair-$USER.pem \
+    --ssh_key_file=/home/$USER/.aws/fair-$USER.pem \
     --instances=i-038dd14b9383b9d79,i-08f057b9c03d4a916 \
     --aux_files=crypten/examples/mpc_linear_svm/mpc_linear_svm.py \
     crypten/examples/mpc_linear_svm/launcher.py \
@@ -42,7 +42,7 @@
 $ python3 crypten/scripts/aws_launcher.py \
     --regions=us-east-1,us-west-1 \
-    --SSH_keys=/home/$USER/.aws/east.pem,/home/$USER/.aws/west.pem \
+    --ssh_key_file=/home/$USER/.aws/east.pem,/home/$USER/.aws/west.pem \
     --instances=i-038dd14b9383b9d79,i-08f057b9c03d4a916 \
     --aux_files=crypten/examples/mpc_linear_svm/mpc_linear_svm.py \
     crypten/examples/mpc_linear_svm/launcher.py \
@@ -55,18 +55,14 @@
 """
 
 import configparser
-import os
-import sys
-import time
 import warnings
-from argparse import REMAINDER, ArgumentParser
 from pathlib import Path
 
 import boto3
-import paramiko
 
 import common
 
+
 def get_instances(ec2, instance_ids):
     instances = list(
         ec2.instances.filter(Filters=[{"Name": "instance-id", "Values": instance_ids}])
     )
     return instances
@@ -157,7 +153,7 @@ def main():
         )
         client_dict[instance.public_ip_address] = client
 
-    remote_dir, script_basename = common.upload_files_to_machines(client_dict, ars.file_paths, args.script)
+    remote_dir, script_basename = common.upload_files_to_machines(client_dict, args.aux_files, args.script)
 
     environment = {
         "WORLD_SIZE": str(world_size),
@@ -166,7 +162,15 @@ def main():
         "MASTER_PORT": str(args.master_port),
     }
 
-    common.run_script_parallel(environment)
+    kwargs = {
+        "environment": environment,
+        "client_dict": client_dict,
+        "remote_dir": remote_dir,
+        "script_basename": script_basename,
+        "script_args": args.script_args,
+        "prepare_cmd": args.prepare_cmd
+    }
+    common.run_script_parallel(**kwargs)
 
     common.cleanup(client_dict, remote_dir)
 
@@ -183,6 +187,13 @@ def get_parser():
         "No other actions will be done",
     )
 
+    parser.add_argument(
+        "--credentials",
+        type=str,
+        default=f"{Path.home()}/.aws/credentials",
+        help="Credentials used to access the machines",
+    )
+
     parser.add_argument("--regions", type=str, default="us-west-2", help="AWS Region")
 
     parser.add_argument(
diff --git a/scripts/multiple_machines/common.py b/scripts/multiple_machines/common.py
index c2cd27a4..c7d7fb84 100644
--- a/scripts/multiple_machines/common.py
+++ b/scripts/multiple_machines/common.py
@@ -10,16 +10,12 @@
 """
 
 import concurrent.futures
-import configparser
-import os
 import sys
 import time
 import uuid
-import warnings
 from argparse import REMAINDER, ArgumentParser
 from pathlib import Path
 
-import boto3
 import paramiko
 
@@ -93,14 +89,14 @@ def upload_file(machine_ip, client, localpath, remotepath):
 
     return remote_dir, script_basename
 
+
 def run_script_parallel(
     environment,
     client_dict,
     remote_dir,
     script_basename,
     script_args=None,
-    prepare_cmd=None
-    ):
+    prepare_cmd=None):
 
     world_size = len(client_dict)
     with concurrent.futures.ThreadPoolExecutor(max_workers=world_size) as executor:
@@ -171,31 +167,6 @@ def cleanup(client_dict, remote_dir):
         client.close()
 
 
-def run_command(machine_ip, client, cmd, environment=None, inputs=None):
-    stdin, stdout, stderr = client.exec_command(
-        cmd, get_pty=True, environment=environment
-    )
-    if inputs:
-        for inp in inputs:
-            stdin.write(inp)
-
-    def read_lines(fin, fout, line_head):
-        line = ""
-        while not fin.channel.exit_status_ready():
-            line += fin.read(1).decode("utf8")
-            if line.endswith("\n"):
-                print(f"{line_head}{line[:-1]}", file=fout)
-                line = ""
-        if line:
-            # print what remains in line buffer, in case fout does not
-            # end with '\n'
-            print(f"{line_head}{line[:-1]}", file=fout)
-
-    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as printer:
-        printer.submit(read_lines, stdout, sys.stdout, f"[{machine_ip} STDOUT] ")
-        printer.submit(read_lines, stderr, sys.stderr, f"[{machine_ip} STDERR] ")
-
-
 def get_parser():
     """
     Helper function parsing the command line options
@@ -206,20 +177,6 @@ def get_parser():
         "parties for MPC scripts on different machines"
     )
 
-    parser.add_argument(
-        "--credentials",
-        type=str,
-        default=f"{Path.home()}/.aws/credentials",
-        help="Credentials used to access the machines",
-    )
-
-    parser.add_argument(
-        "--ip_addresses",
-        type=str,
-        required=True,
-        help="The comma-separated ip addresses for the machines",
-    )
-
     parser.add_argument(
         "--master_port",
         type=int,
@@ -281,7 +238,3 @@ def get_parser():
 
     # the rest of the arguments are passed to the script
     parser.add_argument("script_args", nargs=REMAINDER)
     return parser
-
-
-if __name__ == "__main__":
-    main()
diff --git a/scripts/multiple_machines/generic_launcher.py b/scripts/multiple_machines/generic_launcher.py
index 3d413e3b..4637e204 100644
--- a/scripts/multiple_machines/generic_launcher.py
+++ b/scripts/multiple_machines/generic_launcher.py
@@ -20,18 +20,14 @@
     --skip_plaintext
 """
 
-import concurrent.futures
 import common
-import configparser
 import warnings
 
+
 def main():
-    parser = common.get_parser()
+    parser = get_parser()
     args = parser.parse_args()
 
-    cf = configparser.ConfigParser()
-    cf.read(args.credentials)
-
     warnings.filterwarnings(
         "ignore", category=ResourceWarning, message="unclosed.*"
     )
@@ -76,5 +72,16 @@
     common.cleanup(client_dict, remote_dir)
 
 
+def get_parser():
+    parser = common.get_parser()
+    parser.add_argument(
+        "--ip_addresses",
+        type=str,
+        required=True,
+        help="The comma-separated ip addresses for the machines",
+    )
+
+    return parser
+
 if __name__ == "__main__":
     main()
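
A note on the inline `export` workaround kept in common.py's run_script_parallel(): the TODO comment there explains that paramiko's `environment=` argument to exec_command() often has no effect, because sshd only accepts variables whitelisted via AcceptEnv. That is why the launcher prefixes the remote command with explicit `export KEY=value` statements. Below is a minimal, self-contained sketch of the same workaround; the host, username, and key path are placeholders, not values taken from these patches:

    import paramiko

    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect("192.168.0.1", username="ubuntu", key_filename="/home/user/key.pem")

    env = {"WORLD_SIZE": "2", "RANK": "0", "MASTER_ADDR": "192.168.0.1", "MASTER_PORT": "29500"}
    # Export the variables inside the command itself so the remote shell sets
    # them, bypassing sshd's AcceptEnv filtering of exec_command(environment=...).
    exports = "; ".join(f"export {k}={v}" for k, v in env.items())
    _, stdout, _ = client.exec_command(f"{exports}; env | grep MASTER", get_pty=True)
    print(stdout.read().decode())
    client.close()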