From c8dd31fe7cf7318da8929dc7561e2b81e95ec6c4 Mon Sep 17 00:00:00 2001 From: NKcqx <892670992@qq.com> Date: Mon, 26 Jun 2023 12:06:46 +0800 Subject: [PATCH 01/22] tmp save --- fed/api.py | 4 ++-- fed/barriers.py | 4 ++-- fed/config.py | 18 +++++++++++++++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/fed/api.py b/fed/api.py index edb467ab..19fcacf1 100644 --- a/fed/api.py +++ b/fed/api.py @@ -210,7 +210,7 @@ def init( logger.info(f'Started rayfed with {cluster_config}') set_exit_on_failure_sending(exit_on_failure_cross_silo_sending) - recv_actor_config = fed_config.ProxyActorConfig( + recv_actor_config = fed_config.CrossSiloProxyConfig( resource_label=cross_silo_recv_resource_label) # Start recv proxy start_recv_proxy( @@ -222,7 +222,7 @@ def init( actor_config=recv_actor_config ) - send_actor_config = fed_config.ProxyActorConfig( + send_actor_config = fed_config.CrossSiloProxyConfig( resource_label=cross_silo_send_resource_label) start_send_proxy( cluster=cluster, diff --git a/fed/barriers.py b/fed/barriers.py index f1664e33..7fac9a37 100644 --- a/fed/barriers.py +++ b/fed/barriers.py @@ -378,7 +378,7 @@ def start_recv_proxy( logging_level: str, tls_config=None, retry_policy=None, - actor_config: Optional[fed_config.ProxyActorConfig] = None + actor_config: Optional[fed_config.CrossSiloProxyConfig] = None ): # Create RecevrProxyActor @@ -424,7 +424,7 @@ def start_send_proxy( tls_config: Dict = None, retry_policy=None, max_retries=None, - actor_config: Optional[fed_config.ProxyActorConfig] = None + actor_config: Optional[fed_config.CrossSiloProxyConfig] = None ): # Create SendProxyActor global _SEND_PROXY_ACTOR diff --git a/fed/config.py b/fed/config.py index f3946c16..8dabcfe8 100644 --- a/fed/config.py +++ b/fed/config.py @@ -80,7 +80,7 @@ def get_job_config(): return _job_config -class ProxyActorConfig: +class CrossSiloProxyConfig: """A class to store parameters used for Proxy Actor Attributes: @@ -89,5 +89,17 @@ class ProxyActorConfig: """ def __init__( self, - resource_label: Optional[Dict[str, str]] = None) -> None: - self.resource_label = resource_label + grpc_retry_policy: Dict = None, + send_max_retries: int = None, + timeout_in_seconds: int = 60, + messages_max_size_in_bytes: int = None, + serializing_allowed_list: Optional[Dict[str, str]] = None, + send_resource_label: Optional[Dict[str, str]] = None, + recv_resource_label: Optional[Dict[str, str]] = None) -> None: + self.grpc_retry_policy = grpc_retry_policy + self.send_max_retries = send_max_retries + self.timeout_in_seconds = timeout_in_seconds + self.messages_max_size_in_bytes = messages_max_size_in_bytes + self.serializing_allowed_list = serializing_allowed_list + self.send_resource_label = send_resource_label + self.recv_resource_label = recv_resource_label From ef842a9f75199abf466116a466fcfd3c8904166e Mon Sep 17 00:00:00 2001 From: NKcqx <892670992@qq.com> Date: Mon, 3 Jul 2023 19:21:50 +0800 Subject: [PATCH 02/22] union rpc related config into one Signed-off-by: NKcqx <892670992@qq.com> --- fed/api.py | 40 ++++++++----------- fed/barriers.py | 15 ++++--- fed/config.py | 21 ++++++++-- .../test_unpickle_with_whitelist.py | 6 ++- tests/test_exit_on_failure_sending.py | 9 ++++- tests/test_grpc_options_on_proxies.py | 5 ++- tests/test_grpc_options_per_party.py | 5 ++- tests/test_party_specific_grpc_options.py | 5 ++- tests/test_retry_policy.py | 6 ++- tests/test_setup_proxy_actor.py | 16 +++++--- 10 files changed, 82 insertions(+), 46 deletions(-) diff --git a/fed/api.py b/fed/api.py index 19fcacf1..993ab9e9 100644 --- a/fed/api.py +++ b/fed/api.py @@ -15,7 +15,7 @@ import functools import inspect import logging -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Optional import cloudpickle import ray @@ -29,6 +29,7 @@ from fed._private.global_context import get_global_context, clear_global_context from fed.barriers import ping_others, recv, send, start_recv_proxy, start_send_proxy from fed.cleanup import set_exit_on_failure_sending, wait_sending +from fed.config import CrossSiloCommConfig from fed.fed_object import FedObject from fed.utils import is_ray_object_refs, setup_logger @@ -40,16 +41,8 @@ def init( party: str = None, tls_config: Dict = None, logging_level: str = 'info', - cross_silo_grpc_retry_policy: Dict = None, - cross_silo_send_max_retries: int = None, - cross_silo_serializing_allowed_list: Dict = None, - cross_silo_send_resource_label: Dict = None, - cross_silo_recv_resource_label: Dict = None, - exit_on_failure_cross_silo_sending: bool = False, - cross_silo_messages_max_size_in_bytes: int = None, - cross_silo_timeout_in_seconds: int = 60, enable_waiting_for_other_parties_ready: bool = False, - grpc_metadata: Dict = None, + cross_silo_comm_config: Optional[CrossSiloCommConfig] = None, **kwargs, ): """ @@ -177,6 +170,8 @@ def init( assert ( 'cert' in tls_config and 'key' in tls_config ), 'Cert or key are not in tls_config.' + + cross_silo_comm_config = cross_silo_comm_config or CrossSiloCommConfig() # A Ray private accessing, should be replaced in public API. compatible_utils._init_internal_kv() @@ -185,14 +180,15 @@ def init( constants.KEY_OF_CURRENT_PARTY_NAME: party, constants.KEY_OF_TLS_CONFIG: tls_config, constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: - cross_silo_serializing_allowed_list, + cross_silo_comm_config.serializing_allowed_list, constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: - cross_silo_messages_max_size_in_bytes, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: cross_silo_timeout_in_seconds, + cross_silo_comm_config.messages_max_size_in_bytes, + constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: + cross_silo_comm_config.timeout_in_seconds, } job_config = { - constants.KEY_OF_GRPC_METADATA : grpc_metadata, + constants.KEY_OF_GRPC_METADATA : cross_silo_comm_config.http_header, } compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)) @@ -209,29 +205,25 @@ def init( ) logger.info(f'Started rayfed with {cluster_config}') - set_exit_on_failure_sending(exit_on_failure_cross_silo_sending) - recv_actor_config = fed_config.CrossSiloProxyConfig( - resource_label=cross_silo_recv_resource_label) + set_exit_on_failure_sending(cross_silo_comm_config.exit_on_sending_failure) # Start recv proxy start_recv_proxy( cluster=cluster, party=party, logging_level=logging_level, tls_config=tls_config, - retry_policy=cross_silo_grpc_retry_policy, - actor_config=recv_actor_config + retry_policy=cross_silo_comm_config.grpc_retry_policy, + actor_config=cross_silo_comm_config ) - send_actor_config = fed_config.CrossSiloProxyConfig( - resource_label=cross_silo_send_resource_label) start_send_proxy( cluster=cluster, party=party, logging_level=logging_level, tls_config=tls_config, - retry_policy=cross_silo_grpc_retry_policy, - max_retries=cross_silo_send_max_retries, - actor_config=send_actor_config + retry_policy=cross_silo_comm_config.grpc_retry_policy, + max_retries=cross_silo_comm_config.proxier_fo_max_retries, + actor_config=cross_silo_comm_config ) if enable_waiting_for_other_parties_ready: diff --git a/fed/barriers.py b/fed/barriers.py index 7fac9a37..7e0acb31 100644 --- a/fed/barriers.py +++ b/fed/barriers.py @@ -201,6 +201,7 @@ def __init__( date_format=constants.RAYFED_DATE_FMT, party_val=party, ) + self._stats = {"send_op_count": 0} self._cluster = cluster self._party = party @@ -220,6 +221,8 @@ async def send( upstream_seq_id, downstream_seq_id, ): + raise RuntimeError("[PaerTest] Anything.") + self._stats["send_op_count"] += 1 assert ( dest_party in self._cluster @@ -378,7 +381,7 @@ def start_recv_proxy( logging_level: str, tls_config=None, retry_policy=None, - actor_config: Optional[fed_config.CrossSiloProxyConfig] = None + actor_config: Optional[fed_config.CrossSiloCommConfig] = None ): # Create RecevrProxyActor @@ -390,8 +393,8 @@ def start_recv_proxy( listen_addr = party_addr['address'] actor_options = copy.deepcopy(_DEFAULT_RECV_PROXY_OPTIONS) - if actor_config is not None and actor_config.resource_label is not None: - actor_options.update({"resources": actor_config.resource_label}) + if actor_config is not None and actor_config.recv_resource_label is not None: + actor_options.update({"resources": actor_config.recv_resource_label}) logger.debug(f"Starting RecvProxyActor with options: {actor_options}") @@ -424,7 +427,7 @@ def start_send_proxy( tls_config: Dict = None, retry_policy=None, max_retries=None, - actor_config: Optional[fed_config.CrossSiloProxyConfig] = None + actor_config: Optional[fed_config.CrossSiloCommConfig] = None ): # Create SendProxyActor global _SEND_PROXY_ACTOR @@ -435,8 +438,8 @@ def start_send_proxy( "max_task_retries": max_retries, "max_restarts": 1, }) - if actor_config is not None and actor_config.resource_label is not None: - actor_options.update({"resources": actor_config.resource_label}) + if actor_config is not None and actor_config.send_resource_label is not None: + actor_options.update({"resources": actor_config.send_resource_label}) logger.debug(f"Starting SendProxyActor with options: {actor_options}") _SEND_PROXY_ACTOR = SendProxyActor.options( diff --git a/fed/config.py b/fed/config.py index 8dabcfe8..53cf29f9 100644 --- a/fed/config.py +++ b/fed/config.py @@ -8,6 +8,7 @@ import fed._private.constants as fed_constants import cloudpickle from typing import Dict, Optional +import json class ClusterConfig: @@ -80,7 +81,7 @@ def get_job_config(): return _job_config -class CrossSiloProxyConfig: +class CrossSiloCommConfig: """A class to store parameters used for Proxy Actor Attributes: @@ -90,16 +91,28 @@ class CrossSiloProxyConfig: def __init__( self, grpc_retry_policy: Dict = None, - send_max_retries: int = None, + proxier_fo_max_retries: int = None, timeout_in_seconds: int = 60, messages_max_size_in_bytes: int = None, + exit_on_sending_failure: Optional[bool] = False, serializing_allowed_list: Optional[Dict[str, str]] = None, send_resource_label: Optional[Dict[str, str]] = None, - recv_resource_label: Optional[Dict[str, str]] = None) -> None: + recv_resource_label: Optional[Dict[str, str]] = None, + http_header: Optional[Dict[str, str]] = None) -> None: self.grpc_retry_policy = grpc_retry_policy - self.send_max_retries = send_max_retries + self.proxier_fo_max_retries = proxier_fo_max_retries self.timeout_in_seconds = timeout_in_seconds self.messages_max_size_in_bytes = messages_max_size_in_bytes + self.exit_on_sending_failure = exit_on_sending_failure self.serializing_allowed_list = serializing_allowed_list self.send_resource_label = send_resource_label self.recv_resource_label = recv_resource_label + self.http_header = http_header + + def __json__(self): + return json.dumps(self.__dict__) + + @classmethod + def from_json(cls, json_str): + data = json.loads(json_str) + return cls(**data) diff --git a/tests/serializations_tests/test_unpickle_with_whitelist.py b/tests/serializations_tests/test_unpickle_with_whitelist.py index c4d0a458..28b4ca81 100644 --- a/tests/serializations_tests/test_unpickle_with_whitelist.py +++ b/tests/serializations_tests/test_unpickle_with_whitelist.py @@ -19,6 +19,8 @@ import multiprocessing import numpy +from fed.config import CrossSiloCommConfig + @fed.remote def generate_wrong_type(): @@ -51,7 +53,9 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_serializing_allowed_list=allowed_list) + cross_silo_comm_config=CrossSiloCommConfig( + serializing_allowed_list=allowed_list + )) # Test passing an allowed type. o1 = generate_allowed_type.party("alice").remote() diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 1479a7d6..4b665d4e 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -19,6 +19,8 @@ import fed import fed._private.compatible_utils as compatible_utils +from fed.config import CrossSiloCommConfig + import signal import os @@ -61,12 +63,15 @@ def run(party, is_inner_party): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } + cross_silo_comm_config = CrossSiloCommConfig( + grpc_retry_policy=retry_policy, + exit_on_sending_failure=True + ) fed.init( cluster=cluster, party=party, logging_level='debug', - cross_silo_grpc_retry_policy=retry_policy, - exit_on_failure_cross_silo_sending=True, + cross_silo_comm_config=cross_silo_comm_config ) o = f.party("alice").remote() diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index 993167ce..b3f1b057 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -18,6 +18,8 @@ import fed._private.compatible_utils as compatible_utils import ray +from fed.config import CrossSiloCommConfig + @fed.remote def dummpy(): @@ -33,7 +35,8 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_messages_max_size_in_bytes=100, + cross_silo_comm_config=CrossSiloCommConfig( + messages_max_size_in_bytes=100) ) def _assert_on_proxy(proxy_actor): diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py index b81dbc4b..e99d3d9a 100644 --- a/tests/test_grpc_options_per_party.py +++ b/tests/test_grpc_options_per_party.py @@ -18,6 +18,8 @@ import fed._private.compatible_utils as compatible_utils import ray +from fed.config import CrossSiloCommConfig + @fed.remote def dummpy(): @@ -39,7 +41,8 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_messages_max_size_in_bytes=100, + cross_silo_comm_config=CrossSiloCommConfig( + messages_max_size_in_bytes=100) ) def _assert_on_send_proxy(proxy_actor): diff --git a/tests/test_party_specific_grpc_options.py b/tests/test_party_specific_grpc_options.py index 12936b72..1c017ca9 100644 --- a/tests/test_party_specific_grpc_options.py +++ b/tests/test_party_specific_grpc_options.py @@ -4,6 +4,8 @@ import fed._private.compatible_utils as compatible_utils import ray +from fed.config import CrossSiloCommConfig + @fed.remote def dummpy(): @@ -29,7 +31,8 @@ def party_grpc_options(party): fed.init( cluster=cluster, party=party, - cross_silo_messages_max_size_in_bytes=100 + cross_silo_comm_config=CrossSiloCommConfig( + messages_max_size_in_bytes=100) ) def _assert_on_proxy(proxy_actor): diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 7c33be6a..9fa5f101 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -20,6 +20,8 @@ import fed._private.compatible_utils as compatible_utils import ray +from fed.config import CrossSiloCommConfig + @fed.remote def f(): @@ -51,7 +53,9 @@ def run(party, is_inner_party): fed.init( cluster=cluster, party=party, - cross_silo_grpc_retry_policy=retry_policy, + cross_silo_comm_config=CrossSiloCommConfig( + grpc_retry_policy=retry_policy + ) ) o = f.party("alice").remote() diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index 7a3aabd2..86f30ddc 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -20,6 +20,8 @@ import fed._private.compatible_utils as compatible_utils import ray +from fed.config import CrossSiloCommConfig + def test_setup_proxy_success(): def run(party): @@ -37,8 +39,10 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_send_resource_label=send_proxy_resources, - cross_silo_recv_resource_label=recv_proxy_resources, + cross_silo_comm_config=CrossSiloCommConfig( + send_resource_label=send_proxy_resources, + recv_resource_label=recv_proxy_resources + ) ) assert ray.get_actor("SendProxyActor") is not None @@ -73,9 +77,11 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_send_resource_label=send_proxy_resources, - cross_silo_recv_resource_label=recv_proxy_resources, - cross_silo_timeout_in_seconds=10, # Quick fail in test + cross_silo_comm_config=CrossSiloCommConfig( + send_resource_label=send_proxy_resources, + recv_resource_label=recv_proxy_resources, + timeout_in_seconds=10 # Quick fail in test + ) ) fed.shutdown() From ccf7d51df49579986f965604fc9514e112959073 Mon Sep 17 00:00:00 2001 From: NKcqx <892670992@qq.com> Date: Tue, 4 Jul 2023 10:14:20 +0800 Subject: [PATCH 03/22] union config in KV Signed-off-by: NKcqx <892670992@qq.com> --- fed/_private/constants.py | 2 ++ fed/_private/serialization_utils.py | 3 ++- fed/api.py | 9 ++------- fed/barriers.py | 23 +++++++++++------------ fed/config.py | 4 ++-- 5 files changed, 19 insertions(+), 22 deletions(-) diff --git a/fed/_private/constants.py b/fed/_private/constants.py index 862fbaeb..3b8e9ddd 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -25,6 +25,8 @@ KEY_OF_TLS_CONFIG = "TLS_CONFIG" +KEY_OF_CROSS_SILO_COMM_CONFIG = "CROSS_SILO_COMM_CONFIG" + KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST = "CROSS_SILO_SERIALIZING_ALLOWED_LIST" # noqa KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES = "CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES" # noqa diff --git a/fed/_private/serialization_utils.py b/fed/_private/serialization_utils.py index fafd173e..ea7abe18 100644 --- a/fed/_private/serialization_utils.py +++ b/fed/_private/serialization_utils.py @@ -63,7 +63,8 @@ def find_class(self, module, name): def _apply_loads_function_with_whitelist(): global _pickle_whitelist - _pickle_whitelist = fed_config.get_cluster_config().serializing_allowed_list + _pickle_whitelist = fed_config.get_job_config() \ + .cross_silo_comm_config.serializing_allowed_list if _pickle_whitelist is None: return diff --git a/fed/api.py b/fed/api.py index 993ab9e9..41f410fa 100644 --- a/fed/api.py +++ b/fed/api.py @@ -179,16 +179,11 @@ def init( constants.KEY_OF_CLUSTER_ADDRESSES: cluster, constants.KEY_OF_CURRENT_PARTY_NAME: party, constants.KEY_OF_TLS_CONFIG: tls_config, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: - cross_silo_comm_config.serializing_allowed_list, - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: - cross_silo_comm_config.messages_max_size_in_bytes, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: - cross_silo_comm_config.timeout_in_seconds, } job_config = { - constants.KEY_OF_GRPC_METADATA : cross_silo_comm_config.http_header, + constants.KEY_OF_CROSS_SILO_COMM_CONFIG: + cross_silo_comm_config, } compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)) diff --git a/fed/barriers.py b/fed/barriers.py index 7e0acb31..d3a369e6 100644 --- a/fed/barriers.py +++ b/fed/barriers.py @@ -28,7 +28,7 @@ from fed._private import constants from fed._private.grpc_options import get_grpc_options, set_max_message_length from fed.cleanup import push_to_sending -from fed.config import get_cluster_config +from fed.config import get_job_config from fed.grpc import fed_pb2, fed_pb2_grpc from fed.utils import setup_logger @@ -136,7 +136,7 @@ async def send_data_grpc( grpc_options = get_grpc_options(retry_policy=retry_policy) if \ grpc_options is None else fed_utils.dict2tuple(grpc_options) tls_enabled = fed_utils.tls_enabled(tls_config) - cluster_config = fed_config.get_cluster_config() + timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds metadata = fed_utils.dict2tuple(metadata) if tls_enabled: ca_cert, private_key, cert_chain = fed_utils.load_cert_config(tls_config) @@ -160,7 +160,7 @@ async def send_data_grpc( ) # wait for downstream's reply response = await stub.SendData( - request, metadata=metadata, timeout=cluster_config.cross_silo_timeout) + request, metadata=metadata, timeout=timeout) logger.debug( f'Received data response from seq_id {downstream_seq_id}, ' f'result: {response.result}.' @@ -177,7 +177,7 @@ async def send_data_grpc( ) # wait for downstream's reply response = await stub.SendData( - request, metadata=metadata, timeout=cluster_config.cross_silo_timeout) + request, metadata=metadata, timeout=timeout) logger.debug( f'Received data response from seq_id {downstream_seq_id} ' f'result: {response.result}.' @@ -207,9 +207,9 @@ def __init__( self._party = party self._tls_config = tls_config self.retry_policy = retry_policy - self._grpc_metadata = fed_config.get_job_config().grpc_metadata - cluster_config = fed_config.get_cluster_config() - set_max_message_length(cluster_config.cross_silo_messages_max_size) + cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config + self._grpc_metadata = cross_silo_comm_config.http_header + set_max_message_length(cross_silo_comm_config.messages_max_size_in_bytes) async def is_ready(self): return True @@ -221,7 +221,6 @@ async def send( upstream_seq_id, downstream_seq_id, ): - raise RuntimeError("[PaerTest] Anything.") self._stats["send_op_count"] += 1 assert ( @@ -305,8 +304,8 @@ def __init__( self._party = party self._tls_config = tls_config self.retry_policy = retry_policy - config = fed_config.get_cluster_config() - set_max_message_length(config.cross_silo_messages_max_size) + cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config + set_max_message_length(cross_silo_comm_config.messages_max_size_in_bytes) # Workaround the threading coordinations # Flag to see whether grpc server starts @@ -408,7 +407,7 @@ def start_recv_proxy( retry_policy=retry_policy, ) recver_proxy_actor.run_grpc_server.remote() - timeout = get_cluster_config().cross_silo_timeout + timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout) assert server_state[0], server_state[1] logger.info("RecverProxy has successfully created.") @@ -452,7 +451,7 @@ def start_send_proxy( logging_level=logging_level, retry_policy=retry_policy, ) - timeout = get_cluster_config().cross_silo_timeout + timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout) logger.info("SendProxyActor has successfully created.") diff --git a/fed/config.py b/fed/config.py index 53cf29f9..f8ee7415 100644 --- a/fed/config.py +++ b/fed/config.py @@ -49,8 +49,8 @@ def __init__(self, raw_bytes: bytes) -> None: self._data = cloudpickle.loads(raw_bytes) @property - def grpc_metadata(self): - return self._data.get(fed_constants.KEY_OF_GRPC_METADATA, {}) + def cross_silo_comm_config(self): + return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG, {}) # A module level cache for the cluster configurations. From a046d2763185545a0d253ae4ac954b021f50f148 Mon Sep 17 00:00:00 2001 From: NKcqx <892670992@qq.com> Date: Tue, 4 Jul 2023 10:42:21 +0800 Subject: [PATCH 04/22] update docstr Signed-off-by: NKcqx <892670992@qq.com> --- fed/api.py | 40 ++-------------------------------------- fed/config.py | 39 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/fed/api.py b/fed/api.py index 41f410fa..ab714934 100644 --- a/fed/api.py +++ b/fed/api.py @@ -104,48 +104,12 @@ def init( "cert": "bob's server cert", "key": "bob's server cert key", } - logging_level: optional; the logging level, could be `debug`, `info`, `warning`, `error`, `critical`, not case sensititive. - cross_silo_grpc_retry_policy: a dict descibes the retry policy for - cross silo rpc call. If None, the following default retry policy - will be used. More details please refer to - `retry-policy `_. # noqa - - .. code:: python - { - "maxAttempts": 4, - "initialBackoff": "0.1s", - "maxBackoff": "1s", - "backoffMultiplier": 2, - "retryableStatusCodes": [ - "UNAVAILABLE" - ] - } - cross_silo_send_max_retries: the max retries for sending data cross silo. - cross_silo_serializing_allowed_list: The package or class list allowed for - serializing(deserializating) cross silos. It's used for avoiding pickle - deserializing execution attack when crossing solis. - cross_silo_send_resource_label: Customized resource label, the SendProxyActor - will be scheduled based on the declared resource label. For example, - when setting to `{"my_label": 1}`, then the SendProxyActor will be started - only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. - cross_silo_recv_resource_label: Customized resource label, the RecverProxyActor - will be scheduled based on the declared resource label. For example, - when setting to `{"my_label": 1}`, then the RecverProxyActor will be started - only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. - exit_on_failure_cross_silo_sending: whether exit when failure on - cross-silo sending. If True, a SIGTERM will be signaled to self - if failed to sending cross-silo data. - cross_silo_messages_max_size_in_bytes: The maximum length in bytes of - cross-silo messages. - If None, the default value of 500 MB is specified. - cross_silo_timeout_in_seconds: The timeout in seconds of a cross-silo RPC call. - It's 60 by default. enable_waiting_for_other_parties_ready: ping other parties until they are all ready if True. - grpc_metadata: optional; The metadata sent with the grpc request. This won't override - basic tcp headers, such as `user-agent`, but aggregate them together. + cross_silo_comm_config: Cross-silo communication related config, supported + configs can refer to CrossSiloCommConfig in config.py Examples: >>> import fed diff --git a/fed/config.py b/fed/config.py index f8ee7415..c5f853e7 100644 --- a/fed/config.py +++ b/fed/config.py @@ -85,8 +85,43 @@ class CrossSiloCommConfig: """A class to store parameters used for Proxy Actor Attributes: - resource_label: The customized resources for the actor. This will be - filled into the "resource" field of Ray ActorClass.options. + grpc_retry_policy: a dict descibes the retry policy for + cross silo rpc call. If None, the following default retry policy + will be used. More details please refer to + `retry-policy `_. # noqa + + .. code:: python + { + "maxAttempts": 4, + "initialBackoff": "0.1s", + "maxBackoff": "1s", + "backoffMultiplier": 2, + "retryableStatusCodes": [ + "UNAVAILABLE" + ] + } + proxier_fo_max_retries: The max restart times for the send proxy. + serializing_allowed_list: The package or class list allowed for + serializing(deserializating) cross silos. It's used for avoiding pickle + deserializing execution attack when crossing solis. + send_resource_label: Customized resource label, the SendProxyActor + will be scheduled based on the declared resource label. For example, + when setting to `{"my_label": 1}`, then the SendProxyActor will be started + only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. + recv_resource_label: Customized resource label, the RecverProxyActor + will be scheduled based on the declared resource label. For example, + when setting to `{"my_label": 1}`, then the RecverProxyActor will be started + only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. + exit_on_sending_failure: whether exit when failure on + cross-silo sending. If True, a SIGTERM will be signaled to self + if failed to sending cross-silo data. + messages_max_size_in_bytes: The maximum length in bytes of + cross-silo messages. + If None, the default value of 500 MB is specified. + timeout_in_seconds: The timeout in seconds of a cross-silo RPC call. + It's 60 by default. + http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request. This won't override + basic tcp headers, such as `user-agent`, but concat them together. """ def __init__( self, From aa397a8af8582a897c507084a418115cc242da2d Mon Sep 17 00:00:00 2001 From: NKcqx <892670992@qq.com> Date: Mon, 10 Jul 2023 15:41:20 +0800 Subject: [PATCH 05/22] tmp save --- fed/api.py | 30 +++++++++++++++++++----------- fed/barriers.py | 33 ++++++++++++++++++++++----------- fed/config.py | 47 +++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 84 insertions(+), 26 deletions(-) diff --git a/fed/api.py b/fed/api.py index ab714934..845ecd9a 100644 --- a/fed/api.py +++ b/fed/api.py @@ -42,7 +42,8 @@ def init( tls_config: Dict = None, logging_level: str = 'info', enable_waiting_for_other_parties_ready: bool = False, - cross_silo_comm_config: Optional[CrossSiloCommConfig] = None, + global_cross_silo_comm_config: Optional[CrossSiloCommConfig] = None, + dest_party_comm_config: Optional[Dict[CrossSiloCommConfig]] = None, **kwargs, ): """ @@ -108,8 +109,16 @@ def init( `warning`, `error`, `critical`, not case sensititive. enable_waiting_for_other_parties_ready: ping other parties until they are all ready if True. - cross_silo_comm_config: Cross-silo communication related config, supported - configs can refer to CrossSiloCommConfig in config.py + global_cross_silo_comm_config: Global cross-silo communication related + config that are applied to all connections. Supported configs + can refer to CrossSiloCommConfig in config.py. + dest_party_comm_config: Communication config for the destination party + specifed by the key. E.g. + .. code:: python + { + 'alice': alice_CrossSiloCommConfig, + 'bob': bob_CrossSiloCommConfig + } Examples: >>> import fed @@ -135,7 +144,7 @@ def init( 'cert' in tls_config and 'key' in tls_config ), 'Cert or key are not in tls_config.' - cross_silo_comm_config = cross_silo_comm_config or CrossSiloCommConfig() + global_cross_silo_comm_config = global_cross_silo_comm_config or CrossSiloCommConfig() # A Ray private accessing, should be replaced in public API. compatible_utils._init_internal_kv() @@ -147,7 +156,7 @@ def init( job_config = { constants.KEY_OF_CROSS_SILO_COMM_CONFIG: - cross_silo_comm_config, + global_cross_silo_comm_config, } compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)) @@ -164,15 +173,15 @@ def init( ) logger.info(f'Started rayfed with {cluster_config}') - set_exit_on_failure_sending(cross_silo_comm_config.exit_on_sending_failure) + set_exit_on_failure_sending(global_cross_silo_comm_config.exit_on_sending_failure) # Start recv proxy start_recv_proxy( cluster=cluster, party=party, logging_level=logging_level, tls_config=tls_config, - retry_policy=cross_silo_comm_config.grpc_retry_policy, - actor_config=cross_silo_comm_config + proxy_cls=None, + proxy_config=global_cross_silo_comm_config ) start_send_proxy( @@ -180,9 +189,8 @@ def init( party=party, logging_level=logging_level, tls_config=tls_config, - retry_policy=cross_silo_comm_config.grpc_retry_policy, - max_retries=cross_silo_comm_config.proxier_fo_max_retries, - actor_config=cross_silo_comm_config + proxy_cls=None, + proxy_config=global_cross_silo_comm_config # retry_policy=cross_silo_comm_config.grpc_retry_policy, ) if enable_waiting_for_other_parties_ready: diff --git a/fed/barriers.py b/fed/barriers.py index d3a369e6..073933fa 100644 --- a/fed/barriers.py +++ b/fed/barriers.py @@ -194,6 +194,7 @@ def __init__( tls_config: Dict = None, logging_level: str = None, retry_policy: Dict = None, + proxy_cls = None ): setup_logger( logging_level=logging_level, @@ -221,6 +222,7 @@ async def send( upstream_seq_id, downstream_seq_id, ): + # proxy_cls.send() self._stats["send_op_count"] += 1 assert ( @@ -291,7 +293,9 @@ def __init__( party: str, logging_level: str, tls_config=None, - retry_policy: Dict = None, + proxy_cls = None, + # retry_policy: Dict = None, + ): setup_logger( logging_level=logging_level, @@ -317,6 +321,8 @@ def __init__( self._lock = threading.Lock() async def run_grpc_server(self): + # proxy_cls.run_grpc_server() + try: port = self._listen_addr[self._listen_addr.index(':') + 1 :] await _run_grpc_server( @@ -340,6 +346,10 @@ async def is_ready(self): return self._server_ready_future.result() async def get_data(self, src_aprty, upstream_seq_id, curr_seq_id): + # subscriber + + # proxy_cls.get_data() # get from broker channel + self._stats["receive_op_count"] += 1 data_log_msg = f"data for {curr_seq_id} from {upstream_seq_id} of {src_aprty}" logger.debug(f"Getting {data_log_msg}") @@ -380,7 +390,7 @@ def start_recv_proxy( logging_level: str, tls_config=None, retry_policy=None, - actor_config: Optional[fed_config.CrossSiloCommConfig] = None + proxy_config: Optional[fed_config.CrossSiloCommConfig] = None ): # Create RecevrProxyActor @@ -392,8 +402,8 @@ def start_recv_proxy( listen_addr = party_addr['address'] actor_options = copy.deepcopy(_DEFAULT_RECV_PROXY_OPTIONS) - if actor_config is not None and actor_config.recv_resource_label is not None: - actor_options.update({"resources": actor_config.recv_resource_label}) + if proxy_config is not None and proxy_config.recv_resource_label is not None: + actor_options.update({"resources": proxy_config.recv_resource_label}) logger.debug(f"Starting RecvProxyActor with options: {actor_options}") @@ -425,20 +435,20 @@ def start_send_proxy( logging_level: str, tls_config: Dict = None, retry_policy=None, - max_retries=None, - actor_config: Optional[fed_config.CrossSiloCommConfig] = None + proxy_cls=None, + proxy_config: Optional[fed_config.CrossSiloCommConfig] = None ): # Create SendProxyActor global _SEND_PROXY_ACTOR actor_options = copy.deepcopy(_DEFAULT_SEND_PROXY_OPTIONS) - if max_retries is not None: + if proxy_config and proxy_config.proxier_fo_max_retries: actor_options.update({ - "max_task_retries": max_retries, + "max_task_retries": proxy_config.proxier_fo_max_retries, "max_restarts": 1, }) - if actor_config is not None and actor_config.send_resource_label is not None: - actor_options.update({"resources": actor_config.send_resource_label}) + if proxy_config and proxy_config.send_resource_label: + actor_options.update({"resources": proxy_config.send_resource_label}) logger.debug(f"Starting SendProxyActor with options: {actor_options}") _SEND_PROXY_ACTOR = SendProxyActor.options( @@ -449,7 +459,8 @@ def start_send_proxy( party=party, tls_config=tls_config, logging_level=logging_level, - retry_policy=retry_policy, + # retry_policy=retry_policy, + # starter=server_starter ) timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout) diff --git a/fed/config.py b/fed/config.py index c5f853e7..4c2b19cc 100644 --- a/fed/config.py +++ b/fed/config.py @@ -120,12 +120,12 @@ class CrossSiloCommConfig: If None, the default value of 500 MB is specified. timeout_in_seconds: The timeout in seconds of a cross-silo RPC call. It's 60 by default. - http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request. This won't override - basic tcp headers, such as `user-agent`, but concat them together. + http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request. + This won't override basic tcp headers, such as `user-agent`, but concat + them together. """ def __init__( self, - grpc_retry_policy: Dict = None, proxier_fo_max_retries: int = None, timeout_in_seconds: int = 60, messages_max_size_in_bytes: int = None, @@ -134,7 +134,6 @@ def __init__( send_resource_label: Optional[Dict[str, str]] = None, recv_resource_label: Optional[Dict[str, str]] = None, http_header: Optional[Dict[str, str]] = None) -> None: - self.grpc_retry_policy = grpc_retry_policy self.proxier_fo_max_retries = proxier_fo_max_retries self.timeout_in_seconds = timeout_in_seconds self.messages_max_size_in_bytes = messages_max_size_in_bytes @@ -151,3 +150,43 @@ def __json__(self): def from_json(cls, json_str): data = json.loads(json_str) return cls(**data) + + +class CrossSiloGRPCConfig(CrossSiloCommConfig): + """A class to store parameters used for GRPC communication + + Attributes: + grpc_retry_policy: + grpc_channel_options: A list of tuples to store GRPC channel options, + e.g. [ + ('grpc.enable_retries', 1), + ('grpc.max_send_message_length', 50 * 1024 * 1024) + ] + """ + def __init__(self, + grpc_channel_options, + grpc_retry_policy, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.grpc_retry_policy = grpc_retry_policy + self.grpc_channel_options = grpc_channel_options + + +class CrossSiloBRPCConfig(CrossSiloCommConfig): + """A class to store parameters used for GRPC communication + + Attributes: + grpc_retry_policy: + grpc_channel_options: A list of tuples to store GRPC channel options, + e.g. [ + ('grpc.enable_retries', 1), + ('grpc.max_send_message_length', 50 * 1024 * 1024) + ] + """ + def __init__(self, + brpc_options, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.brpc_options = brpc_options From 40cec6513b8b9eadc3f1077b700abebfb80a05b0 Mon Sep 17 00:00:00 2001 From: paer Date: Mon, 10 Jul 2023 22:39:09 +0800 Subject: [PATCH 06/22] fix conflicts --- fed/api.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/fed/api.py b/fed/api.py index 7d27e3da..1fc96c31 100644 --- a/fed/api.py +++ b/fed/api.py @@ -180,19 +180,8 @@ def init( ) logger.info(f'Started rayfed with {cluster_config}') -<<<<<<< HEAD - set_exit_on_failure_sending(global_cross_silo_comm_config.exit_on_sending_failure) -||||||| e98fd36 - set_exit_on_failure_sending(exit_on_failure_cross_silo_sending) - recv_actor_config = fed_config.ProxyActorConfig( - resource_label=cross_silo_recv_resource_label) -======= get_global_context().get_cleanup_manager().start( - exit_when_failure_sending=exit_on_failure_cross_silo_sending) - - recv_actor_config = fed_config.ProxyActorConfig( - resource_label=cross_silo_recv_resource_label) ->>>>>>> main + exit_when_failure_sending=global_cross_silo_comm_config.exit_on_sending_failure) # Start recv proxy start_recv_proxy( cluster=cluster, From ed18a07a7dbacb8d354514ab9fc40d1a80400b9d Mon Sep 17 00:00:00 2001 From: paer Date: Mon, 10 Jul 2023 22:42:02 +0800 Subject: [PATCH 07/22] fix conflicts --- fed/api.py | 2 -- fed/proxy/barriers.py | 8 -------- 2 files changed, 10 deletions(-) diff --git a/fed/api.py b/fed/api.py index 1fc96c31..b1b6bd02 100644 --- a/fed/api.py +++ b/fed/api.py @@ -34,7 +34,6 @@ start_recv_proxy, start_send_proxy, ) -from fed.cleanup import set_exit_on_failure_sending, wait_sending from fed.config import CrossSiloCommConfig from fed.fed_object import FedObject @@ -50,7 +49,6 @@ def init( logging_level: str = 'info', enable_waiting_for_other_parties_ready: bool = False, global_cross_silo_comm_config: Optional[CrossSiloCommConfig] = None, - dest_party_comm_config: Optional[Dict[CrossSiloCommConfig]] = None, **kwargs, ): """ diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 77aa1d31..5626b322 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -27,15 +27,7 @@ import fed.utils as fed_utils from fed._private import constants from fed._private.grpc_options import get_grpc_options, set_max_message_length -<<<<<<< HEAD:fed/barriers.py -from fed.cleanup import push_to_sending from fed.config import get_job_config -||||||| e98fd36:fed/barriers.py -from fed.cleanup import push_to_sending -from fed.config import get_cluster_config -======= -from fed.config import get_cluster_config ->>>>>>> main:fed/proxy/barriers.py from fed.grpc import fed_pb2, fed_pb2_grpc from fed.utils import setup_logger from fed._private.global_context import get_global_context From 51e4e70e137b720686ec07ab40e67f553fa6999e Mon Sep 17 00:00:00 2001 From: paer Date: Mon, 10 Jul 2023 23:45:20 +0800 Subject: [PATCH 08/22] Pluggable cross_silo rpc impl --- fed/api.py | 8 +- fed/config.py | 2 +- fed/proxy/barriers.py | 275 +++++++++------------------------------- fed/proxy/grpc_proxy.py | 253 ++++++++++++++++++++++++++++++++++++ 4 files changed, 324 insertions(+), 214 deletions(-) create mode 100644 fed/proxy/grpc_proxy.py diff --git a/fed/api.py b/fed/api.py index b1b6bd02..184738f1 100644 --- a/fed/api.py +++ b/fed/api.py @@ -33,6 +33,7 @@ send, start_recv_proxy, start_send_proxy, + SendProxy ) from fed.config import CrossSiloCommConfig @@ -48,6 +49,7 @@ def init( tls_config: Dict = None, logging_level: str = 'info', enable_waiting_for_other_parties_ready: bool = False, + send_proxy_cls: SendProxy = None, global_cross_silo_comm_config: Optional[CrossSiloCommConfig] = None, **kwargs, ): @@ -190,12 +192,16 @@ def init( proxy_config=global_cross_silo_comm_config ) + if send_proxy_cls is None: + from fed.proxy.grpc_proxy import GrpcSendProxy + send_proxy_cls = GrpcSendProxy + start_send_proxy( cluster=cluster, party=party, logging_level=logging_level, tls_config=tls_config, - proxy_cls=None, + proxy_cls=send_proxy_cls, proxy_config=global_cross_silo_comm_config # retry_policy=cross_silo_comm_config.grpc_retry_policy, ) diff --git a/fed/config.py b/fed/config.py index 4c2b19cc..9d3f53c3 100644 --- a/fed/config.py +++ b/fed/config.py @@ -152,7 +152,7 @@ def from_json(cls, json_str): return cls(**data) -class CrossSiloGRPCConfig(CrossSiloCommConfig): +class CrossSiloGrpcCommConfig(CrossSiloCommConfig): """A class to store parameters used for GRPC communication Attributes: diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 5626b322..bebd1e40 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import asyncio import logging import threading @@ -26,9 +27,8 @@ import fed.config as fed_config import fed.utils as fed_utils from fed._private import constants -from fed._private.grpc_options import get_grpc_options, set_max_message_length -from fed.config import get_job_config -from fed.grpc import fed_pb2, fed_pb2_grpc + +from fed.config import get_job_config, CrossSiloCommConfig from fed.utils import setup_logger from fed._private.global_context import get_global_context @@ -60,94 +60,57 @@ def pop_from_two_dim_dict(the_dict, key_a, key_b): return the_dict[key_a].pop(key_b) -class SendDataService(fed_pb2_grpc.GrpcServiceServicer): - def __init__(self, all_events, all_data, party, lock): - self._events = all_events - self._all_data = all_data +class SendProxy(abc.ABC): + def __init__( + self, + cluster: Dict, + party: str, + proxy_config = None) -> None: + self._cluster = cluster self._party = party - self._lock = lock + self._proxy_config = proxy_config - async def SendData(self, request, context): - upstream_seq_id = request.upstream_seq_id - downstream_seq_id = request.downstream_seq_id - logger.debug( - f'Received a grpc data request from {upstream_seq_id} to ' - f'{downstream_seq_id}.' - ) + + @abc.abstractmethod + async def send( + self, + dest_party, + data, + upstream_seq_id, + downstream_seq_id + ): + pass - with self._lock: - add_two_dim_dict( - self._all_data, upstream_seq_id, downstream_seq_id, request.data - ) - if not key_exists_in_two_dim_dict( - self._events, upstream_seq_id, downstream_seq_id - ): - event = asyncio.Event() - add_two_dim_dict( - self._events, upstream_seq_id, downstream_seq_id, event - ) - event = get_from_two_dim_dict(self._events, upstream_seq_id, downstream_seq_id) - event.set() - logger.debug(f"Event set for {upstream_seq_id}") - return fed_pb2.SendDataResponse(result="OK") - - -async def _run_grpc_server( - port, event, all_data, party, lock, - server_ready_future, tls_config=None, grpc_options=None -): - server = grpc.aio.server(options=grpc_options) - fed_pb2_grpc.add_GrpcServiceServicer_to_server( - SendDataService(event, all_data, party, lock), server - ) + async def is_ready(): + return True - tls_enabled = fed_utils.tls_enabled(tls_config) - if tls_enabled: - ca_cert, private_key, cert_chain = fed_utils.load_cert_config(tls_config) - server_credentials = grpc.ssl_server_credentials( - [(private_key, cert_chain)], - root_certificates=ca_cert, - require_client_auth=ca_cert is not None, - ) - server.add_secure_port(f'[::]:{port}', server_credentials) - else: - server.add_insecure_port(f'[::]:{port}') +class RecvProxy(abc.ABC): + def __init__( + self, + listen_addr: str, + party: str, + tls_config: Dict, + proxy_config: CrossSiloCommConfig) -> None: + self._listen_addr = listen_addr + self._party = party + self._tls_config = tls_config + self._proxy_config = proxy_config - msg = f"Succeeded to add port {port}." - await server.start() - logger.info( - f'Successfully start Grpc service with{"out" if not tls_enabled else ""} ' - 'credentials.' - ) - server_ready_future.set_result((True, msg)) - await server.wait_for_termination() + @abc.abstractmethod + def start(self): + pass -async def send_data_grpc( - data, - stub, - upstream_seq_id, - downstream_seq_id, - metadata=None, -): - cluster_config = fed_config.get_cluster_config() - data = cloudpickle.dumps(data) - request = fed_pb2.SendDataRequest( - data=data, - upstream_seq_id=str(upstream_seq_id), - downstream_seq_id=str(downstream_seq_id), - ) - # Waiting for the reply from downstream. - response = await stub.SendData( - request, - metadata=fed_utils.dict2tuple(metadata), - timeout=cluster_config.cross_silo_timeout, - ) - logger.debug( - f'Received data response from seq_id {downstream_seq_id}, ' - f'result: {response.result}.' - ) - return response.result + @abc.abstractmethod + async def get_data( + self, + src_party, + upstream_seq_id, + curr_seq_id): + pass + + async def is_ready(self): + return True @ray.remote @@ -158,7 +121,6 @@ def __init__( party: str, tls_config: Dict = None, logging_level: str = None, - retry_policy: Dict = None, proxy_cls = None ): setup_logger( @@ -172,15 +134,11 @@ def __init__( self._cluster = cluster self._party = party self._tls_config = tls_config - self.retry_policy = retry_policy cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config - self._grpc_metadata = cross_silo_comm_config.http_header - set_max_message_length(cross_silo_comm_config.messages_max_size_in_bytes) - # Mapping the destination party name to the reused client stub. - self._stubs = {} + self.proxy_instance: SendProxy = proxy_cls(cluster, party, cross_silo_comm_config) async def is_ready(self): - return True + return self.proxy_instance.is_ready() async def send( self, @@ -189,8 +147,6 @@ async def send( upstream_seq_id, downstream_seq_id, ): - # proxy_cls.send() - self._stats["send_op_count"] += 1 assert ( dest_party in self._cluster @@ -203,68 +159,18 @@ async def send( f'Sending {send_log_msg} with{"out" if not self._tls_config else ""}' ' credentials.' ) - dest_addr = self._cluster[dest_party]['address'] - dest_party_grpc_config = self.setup_grpc_config(dest_party) try: - tls_enabled = fed_utils.tls_enabled(self._tls_config) - grpc_options = dest_party_grpc_config['grpc_options'] - grpc_options = get_grpc_options(retry_policy=self.retry_policy) if \ - grpc_options is None else fed_utils.dict2tuple(grpc_options) - - if dest_party not in self._stubs: - if tls_enabled: - ca_cert, private_key, cert_chain = fed_utils.load_cert_config( - self._tls_config) - credentials = grpc.ssl_channel_credentials( - certificate_chain=cert_chain, - private_key=private_key, - root_certificates=ca_cert, - ) - channel = grpc.aio.secure_channel( - dest_addr, credentials, options=grpc_options) - else: - channel = grpc.aio.insecure_channel(dest_addr, options=grpc_options) - stub = fed_pb2_grpc.GrpcServiceStub(channel) - self._stubs[dest_party] = stub - - response = await send_data_grpc( - data=data, - stub=self._stubs[dest_party], - upstream_seq_id=upstream_seq_id, - downstream_seq_id=downstream_seq_id, - metadata=dest_party_grpc_config['grpc_metadata'], - ) + response = await self.proxy_instance.send(dest_party, data, upstream_seq_id, downstream_seq_id) except Exception as e: logger.error(f'Failed to {send_log_msg}, error: {e}') return False logger.debug(f"Succeeded to send {send_log_msg}. Response is {response}") return True # True indicates it's sent successfully. - def setup_grpc_config(self, dest_party): - dest_party_grpc_config = {} - global_grpc_metadata = ( - dict(self._grpc_metadata) if self._grpc_metadata is not None else {} - ) - dest_party_grpc_metadata = dict( - self._cluster[dest_party].get('grpc_metadata', {}) - ) - # merge grpc metadata - dest_party_grpc_config['grpc_metadata'] = { - **global_grpc_metadata, **dest_party_grpc_metadata} - - global_grpc_options = dict(get_grpc_options(self.retry_policy)) - dest_party_grpc_options = dict( - self._cluster[dest_party].get('grpc_options', {}) - ) - dest_party_grpc_config['grpc_options'] = { - **global_grpc_options, **dest_party_grpc_options} - return dest_party_grpc_config async def _get_stats(self): return self._stats - async def _get_grpc_options(self): - return get_grpc_options() async def _get_cluster_info(self): return self._cluster @@ -279,8 +185,6 @@ def __init__( logging_level: str, tls_config=None, proxy_cls = None, - # retry_policy: Dict = None, - ): setup_logger( logging_level=logging_level, @@ -292,76 +196,25 @@ def __init__( self._listen_addr = listen_addr self._party = party self._tls_config = tls_config - self.retry_policy = retry_policy cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config set_max_message_length(cross_silo_comm_config.messages_max_size_in_bytes) - # Workaround the threading coordinations - - # Flag to see whether grpc server starts - self._server_ready_future = asyncio.Future() - - # All events for grpc waitting usage. - self._events = {} # map from (upstream_seq_id, downstream_seq_id) to event - self._all_data = {} # map from (upstream_seq_id, downstream_seq_id) to data - self._lock = threading.Lock() - - async def run_grpc_server(self): - # proxy_cls.run_grpc_server() - - try: - port = self._listen_addr[self._listen_addr.index(':') + 1 :] - await _run_grpc_server( - port, - self._events, - self._all_data, - self._party, - self._lock, - self._server_ready_future, - self._tls_config, - get_grpc_options(self.retry_policy), - ) - except RuntimeError as err: - msg = f'Grpc server failed to listen to port: {port}' \ - f' Try another port by setting `listen_addr` into `cluster` config' \ - f' when calling `fed.init`. Grpc error msg: {err}' - self._server_ready_future.set_result((False, msg)) + self._proxy_instance: RecvProxy = proxy_cls(listen_addr, party, cross_silo_comm_config) + async def start(self): + await self._proxy_instance.start() + async def is_ready(self): - await self._server_ready_future - return self._server_ready_future.result() + return self._proxy_instance.is_ready() - async def get_data(self, src_aprty, upstream_seq_id, curr_seq_id): - # subscriber - - # proxy_cls.get_data() # get from broker channel - - self._stats["receive_op_count"] += 1 - data_log_msg = f"data for {curr_seq_id} from {upstream_seq_id} of {src_aprty}" - logger.debug(f"Getting {data_log_msg}") - with self._lock: - if not key_exists_in_two_dim_dict( - self._events, upstream_seq_id, curr_seq_id - ): - add_two_dim_dict( - self._events, upstream_seq_id, curr_seq_id, asyncio.Event() - ) - curr_event = get_from_two_dim_dict(self._events, upstream_seq_id, curr_seq_id) - await curr_event.wait() - logging.debug(f"Waited {data_log_msg}.") - with self._lock: - data = pop_from_two_dim_dict(self._all_data, upstream_seq_id, curr_seq_id) - pop_from_two_dim_dict(self._events, upstream_seq_id, curr_seq_id) - - # NOTE(qwang): This is used to avoid the conflict with pickle5 in Ray. - import fed._private.serialization_utils as fed_ser_utils - fed_ser_utils._apply_loads_function_with_whitelist() + + async def get_data(self, src_party, upstream_seq_id, curr_seq_id): + self._stats["receive_op_count"] += 1 + data = await self._proxy_instance.get_data(src_party, upstream_seq_id, curr_seq_id) return cloudpickle.loads(data) async def _get_stats(self): return self._stats - async def _get_grpc_options(self): - return get_grpc_options() _DEFAULT_RECV_PROXY_OPTIONS = { @@ -374,7 +227,7 @@ def start_recv_proxy( party: str, logging_level: str, tls_config=None, - retry_policy=None, + proxy_cls=None, proxy_config: Optional[fed_config.CrossSiloCommConfig] = None ): @@ -399,9 +252,9 @@ def start_recv_proxy( party=party, tls_config=tls_config, logging_level=logging_level, - retry_policy=retry_policy, + proxy_cls=proxy_cls ) - recver_proxy_actor.run_grpc_server.remote() + recver_proxy_actor.start.remote() timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout) assert server_state[0], server_state[1] @@ -419,7 +272,6 @@ def start_send_proxy( party: str, logging_level: str, tls_config: Dict = None, - retry_policy=None, proxy_cls=None, proxy_config: Optional[fed_config.CrossSiloCommConfig] = None ): @@ -444,8 +296,7 @@ def start_send_proxy( party=party, tls_config=tls_config, logging_level=logging_level, - # retry_policy=retry_policy, - # starter=server_starter + proxy_cls=proxy_cls ) timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout) diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc_proxy.py new file mode 100644 index 00000000..9124c079 --- /dev/null +++ b/fed/proxy/grpc_proxy.py @@ -0,0 +1,253 @@ +import asyncio +import cloudpickle +import grpc +import logging +import threading +from typing import Dict + + +import fed.config as fed_config +import fed.utils as fed_utils + +from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig +from fed._private.grpc_options import get_grpc_options, set_max_message_length +from fed.proxy.barriers import ( + add_two_dim_dict, + get_from_two_dim_dict, + pop_from_two_dim_dict, + key_exists_in_two_dim_dict, + SendProxy, + RecvProxy +) +from fed.grpc import fed_pb2, fed_pb2_grpc + + +logger = logging.getLogger(__name__) + + + +class GrpcSendProxy(SendProxy): + def __init__(self, cluster: Dict, party: str, proxy_config=None) -> None: + super().__init__(cluster, party, proxy_config) + self._grpc_metadata = proxy_config.http_header + set_max_message_length(proxy_config.messages_max_size_in_bytes) + # Mapping the destination party name to the reused client stub. + self._stubs = {} + + async def send( + self, + dest_party, + data, + upstream_seq_id, + downstream_seq_id): + dest_addr = self._cluster[dest_party]['address'] + dest_party_grpc_config = self.setup_grpc_config(dest_party) + tls_enabled = fed_utils.tls_enabled(self._tls_config) + grpc_options = dest_party_grpc_config['grpc_options'] + grpc_options = get_grpc_options(retry_policy=self.retry_policy) if \ + grpc_options is None else fed_utils.dict2tuple(grpc_options) + if dest_party not in self._stubs: + if tls_enabled: + ca_cert, private_key, cert_chain = fed_utils.load_cert_config( + self._tls_config) + credentials = grpc.ssl_channel_credentials( + certificate_chain=cert_chain, + private_key=private_key, + root_certificates=ca_cert, + ) + channel = grpc.aio.secure_channel( + dest_addr, credentials, options=grpc_options) + else: + channel = grpc.aio.insecure_channel(dest_addr, options=grpc_options) + stub = fed_pb2_grpc.GrpcServiceStub(channel) + self._stubs[dest_party] = stub + + response = await send_data_grpc( + data=data, + stub=self._stubs[dest_party], + upstream_seq_id=upstream_seq_id, + downstream_seq_id=downstream_seq_id, + metadata=dest_party_grpc_config['grpc_metadata'], + ) + return response + + def setup_grpc_config(self, dest_party): + dest_party_grpc_config = {} + global_grpc_metadata = ( + dict(self._grpc_metadata) if self._grpc_metadata is not None else {} + ) + dest_party_grpc_metadata = dict( + self._cluster[dest_party].get('grpc_metadata', {}) + ) + # merge grpc metadata + dest_party_grpc_config['grpc_metadata'] = { + **global_grpc_metadata, **dest_party_grpc_metadata} + + global_grpc_options = dict(get_grpc_options(self.retry_policy)) + dest_party_grpc_options = dict( + self._cluster[dest_party].get('grpc_options', {}) + ) + dest_party_grpc_config['grpc_options'] = { + **global_grpc_options, **dest_party_grpc_options} + return dest_party_grpc_config + + async def _get_grpc_options(self): + return get_grpc_options() + + + +async def send_data_grpc( + data, + stub, + upstream_seq_id, + downstream_seq_id, + metadata=None, +): + cluster_config = fed_config.get_cluster_config() + data = cloudpickle.dumps(data) + request = fed_pb2.SendDataRequest( + data=data, + upstream_seq_id=str(upstream_seq_id), + downstream_seq_id=str(downstream_seq_id), + ) + # Waiting for the reply from downstream. + response = await stub.SendData( + request, + metadata=fed_utils.dict2tuple(metadata), + timeout=cluster_config.cross_silo_timeout, + ) + logger.debug( + f'Received data response from seq_id {downstream_seq_id}, ' + f'result: {response.result}.' + ) + return response.result + + +class GrpcRecvProxy(RecvProxy): + def __init__(self, listen_addr: str, party: str, proxy_config: CrossSiloCommConfig) -> None: + super().__init__(listen_addr, party, proxy_config) + # Flag to see whether grpc server starts + self._server_ready_future = asyncio.Future() + self._retry_policy = None + if isinstance(proxy_config, CrossSiloGrpcCommConfig): + self._retry_policy = proxy_config.grpc_retry_policy + + # All events for grpc waitting usage. + self._events = {} # map from (upstream_seq_id, downstream_seq_id) to event + self._all_data = {} # map from (upstream_seq_id, downstream_seq_id) to data + self._lock = threading.Lock() + + async def start(self): + port = self._listen_addr[self._listen_addr.index(':') + 1 :] + try: + await _run_grpc_server( + port, + self._events, + self._all_data, + self._party, + self._lock, + self._server_ready_future, + self._tls_config, + get_grpc_options(self._retry_policy), + ) + except RuntimeError as err: + msg = f'Grpc server failed to listen to port: {port}' \ + f' Try another port by setting `listen_addr` into `cluster` config' \ + f' when calling `fed.init`. Grpc error msg: {err}' + self._server_ready_future.set_result((False, msg)) + + + async def is_ready(self): + await self._server_ready_future + return self._server_ready_future.result() + + async def get_data(self, src_party, upstream_seq_id, curr_seq_id): + data_log_msg = f"data for {curr_seq_id} from {upstream_seq_id} of {src_party}" + logger.debug(f"Getting {data_log_msg}") + with self._lock: + if not key_exists_in_two_dim_dict( + self._events, upstream_seq_id, curr_seq_id + ): + add_two_dim_dict( + self._events, upstream_seq_id, curr_seq_id, asyncio.Event() + ) + curr_event = get_from_two_dim_dict(self._events, upstream_seq_id, curr_seq_id) + await curr_event.wait() + logger.debug(f"Waited {data_log_msg}.") + with self._lock: + data = pop_from_two_dim_dict(self._all_data, upstream_seq_id, curr_seq_id) + pop_from_two_dim_dict(self._events, upstream_seq_id, curr_seq_id) + + # NOTE(qwang): This is used to avoid the conflict with pickle5 in Ray. + import fed._private.serialization_utils as fed_ser_utils + fed_ser_utils._apply_loads_function_with_whitelist() + return cloudpickle.loads(data) + + async def _get_grpc_options(self): + return get_grpc_options() + + + + +class SendDataService(fed_pb2_grpc.GrpcServiceServicer): + def __init__(self, all_events, all_data, party, lock): + self._events = all_events + self._all_data = all_data + self._party = party + self._lock = lock + + async def SendData(self, request, context): + upstream_seq_id = request.upstream_seq_id + downstream_seq_id = request.downstream_seq_id + logger.debug( + f'Received a grpc data request from {upstream_seq_id} to ' + f'{downstream_seq_id}.' + ) + + with self._lock: + add_two_dim_dict( + self._all_data, upstream_seq_id, downstream_seq_id, request.data + ) + if not key_exists_in_two_dim_dict( + self._events, upstream_seq_id, downstream_seq_id + ): + event = asyncio.Event() + add_two_dim_dict( + self._events, upstream_seq_id, downstream_seq_id, event + ) + event = get_from_two_dim_dict(self._events, upstream_seq_id, downstream_seq_id) + event.set() + logger.debug(f"Event set for {upstream_seq_id}") + return fed_pb2.SendDataResponse(result="OK") + + +async def _run_grpc_server( + port, event, all_data, party, lock, + server_ready_future, tls_config=None, grpc_options=None +): + server = grpc.aio.server(options=grpc_options) + fed_pb2_grpc.add_GrpcServiceServicer_to_server( + SendDataService(event, all_data, party, lock), server + ) + + tls_enabled = fed_utils.tls_enabled(tls_config) + if tls_enabled: + ca_cert, private_key, cert_chain = fed_utils.load_cert_config(tls_config) + server_credentials = grpc.ssl_server_credentials( + [(private_key, cert_chain)], + root_certificates=ca_cert, + require_client_auth=ca_cert is not None, + ) + server.add_secure_port(f'[::]:{port}', server_credentials) + else: + server.add_insecure_port(f'[::]:{port}') + + msg = f"Succeeded to add port {port}." + await server.start() + logger.info( + f'Successfully start Grpc service with{"out" if not tls_enabled else ""} ' + 'credentials.' + ) + server_ready_future.set_result((True, msg)) + await server.wait_for_termination() + From 8ff56580ea5168b7de7b8326c72cd33fc56b8f9b Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 11 Jul 2023 01:13:19 +0800 Subject: [PATCH 09/22] parameter passng bugs --- fed/api.py | 11 ++++++++--- fed/config.py | 31 +++++++++++++++---------------- fed/proxy/barriers.py | 30 +++++++++++++++--------------- fed/proxy/grpc_proxy.py | 26 +++++++++++++++++--------- fed/tmp.py | 10 ++++++++++ tests/test_retry_policy.py | 4 ++-- 6 files changed, 67 insertions(+), 45 deletions(-) create mode 100644 fed/tmp.py diff --git a/fed/api.py b/fed/api.py index 184738f1..7375d16d 100644 --- a/fed/api.py +++ b/fed/api.py @@ -33,7 +33,8 @@ send, start_recv_proxy, start_send_proxy, - SendProxy + SendProxy, + RecvProxy ) from fed.config import CrossSiloCommConfig @@ -50,6 +51,7 @@ def init( logging_level: str = 'info', enable_waiting_for_other_parties_ready: bool = False, send_proxy_cls: SendProxy = None, + recv_proxy_cls: RecvProxy = None, global_cross_silo_comm_config: Optional[CrossSiloCommConfig] = None, **kwargs, ): @@ -182,20 +184,23 @@ def init( logger.info(f'Started rayfed with {cluster_config}') get_global_context().get_cleanup_manager().start( exit_when_failure_sending=global_cross_silo_comm_config.exit_on_sending_failure) + + if recv_proxy_cls is None: + from fed.proxy.grpc_proxy import GrpcRecvProxy + recv_proxy_cls = GrpcRecvProxy # Start recv proxy start_recv_proxy( cluster=cluster, party=party, logging_level=logging_level, tls_config=tls_config, - proxy_cls=None, + proxy_cls=recv_proxy_cls, proxy_config=global_cross_silo_comm_config ) if send_proxy_cls is None: from fed.proxy.grpc_proxy import GrpcSendProxy send_proxy_cls = GrpcSendProxy - start_send_proxy( cluster=cluster, party=party, diff --git a/fed/config.py b/fed/config.py index 9d3f53c3..d43be61e 100644 --- a/fed/config.py +++ b/fed/config.py @@ -85,21 +85,6 @@ class CrossSiloCommConfig: """A class to store parameters used for Proxy Actor Attributes: - grpc_retry_policy: a dict descibes the retry policy for - cross silo rpc call. If None, the following default retry policy - will be used. More details please refer to - `retry-policy `_. # noqa - - .. code:: python - { - "maxAttempts": 4, - "initialBackoff": "0.1s", - "maxBackoff": "1s", - "backoffMultiplier": 2, - "retryableStatusCodes": [ - "UNAVAILABLE" - ] - } proxier_fo_max_retries: The max restart times for the send proxy. serializing_allowed_list: The package or class list allowed for serializing(deserializating) cross silos. It's used for avoiding pickle @@ -156,7 +141,21 @@ class CrossSiloGrpcCommConfig(CrossSiloCommConfig): """A class to store parameters used for GRPC communication Attributes: - grpc_retry_policy: + grpc_retry_policy: a dict descibes the retry policy for + cross silo rpc call. If None, the following default retry policy + will be used. More details please refer to + `retry-policy `_. # noqa + + .. code:: python + { + "maxAttempts": 4, + "initialBackoff": "0.1s", + "maxBackoff": "1s", + "backoffMultiplier": 2, + "retryableStatusCodes": [ + "UNAVAILABLE" + ] + } grpc_channel_options: A list of tuples to store GRPC channel options, e.g. [ ('grpc.enable_retries', 1), diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index bebd1e40..a7279390 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -21,11 +21,9 @@ from typing import Dict, Optional import cloudpickle -import grpc import ray import fed.config as fed_config -import fed.utils as fed_utils from fed._private import constants from fed.config import get_job_config, CrossSiloCommConfig @@ -62,12 +60,14 @@ def pop_from_two_dim_dict(the_dict, key_a, key_b): class SendProxy(abc.ABC): def __init__( - self, - cluster: Dict, - party: str, - proxy_config = None) -> None: + self, + cluster: Dict, + party: str, + tls_config: Dict, + proxy_config = None) -> None: self._cluster = cluster self._party = party + self._tls_config = tls_config self._proxy_config = proxy_config @@ -81,7 +81,7 @@ async def send( ): pass - async def is_ready(): + async def is_ready(self): return True class RecvProxy(abc.ABC): @@ -135,10 +135,11 @@ def __init__( self._party = party self._tls_config = tls_config cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config - self.proxy_instance: SendProxy = proxy_cls(cluster, party, cross_silo_comm_config) + self.proxy_instance: SendProxy = proxy_cls(cluster, party, tls_config, cross_silo_comm_config) async def is_ready(self): - return self.proxy_instance.is_ready() + res = await self.proxy_instance.is_ready() + return res async def send( self, @@ -197,20 +198,19 @@ def __init__( self._party = party self._tls_config = tls_config cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config - set_max_message_length(cross_silo_comm_config.messages_max_size_in_bytes) - self._proxy_instance: RecvProxy = proxy_cls(listen_addr, party, cross_silo_comm_config) + self._proxy_instance: RecvProxy = proxy_cls(listen_addr, party, tls_config, cross_silo_comm_config) async def start(self): await self._proxy_instance.start() - - async def is_ready(self): - return self._proxy_instance.is_ready() + async def is_ready(self): + res = await self._proxy_instance.is_ready() + return res async def get_data(self, src_party, upstream_seq_id, curr_seq_id): self._stats["receive_op_count"] += 1 data = await self._proxy_instance.get_data(src_party, upstream_seq_id, curr_seq_id) - return cloudpickle.loads(data) + return data async def _get_stats(self): return self._stats diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc_proxy.py index 9124c079..3cc88c01 100644 --- a/fed/proxy/grpc_proxy.py +++ b/fed/proxy/grpc_proxy.py @@ -27,10 +27,13 @@ class GrpcSendProxy(SendProxy): - def __init__(self, cluster: Dict, party: str, proxy_config=None) -> None: - super().__init__(cluster, party, proxy_config) + def __init__(self, cluster: Dict, party: str, tls_config: Dict, proxy_config=None) -> None: + super().__init__(cluster, party, tls_config, proxy_config) self._grpc_metadata = proxy_config.http_header set_max_message_length(proxy_config.messages_max_size_in_bytes) + self._retry_policy = None + if isinstance(proxy_config, CrossSiloGrpcCommConfig): + self._retry_policy = proxy_config.grpc_retry_policy # Mapping the destination party name to the reused client stub. self._stubs = {} @@ -44,7 +47,7 @@ async def send( dest_party_grpc_config = self.setup_grpc_config(dest_party) tls_enabled = fed_utils.tls_enabled(self._tls_config) grpc_options = dest_party_grpc_config['grpc_options'] - grpc_options = get_grpc_options(retry_policy=self.retry_policy) if \ + grpc_options = get_grpc_options(retry_policy=self._retry_policy) if \ grpc_options is None else fed_utils.dict2tuple(grpc_options) if dest_party not in self._stubs: if tls_enabled: @@ -62,11 +65,13 @@ async def send( stub = fed_pb2_grpc.GrpcServiceStub(channel) self._stubs[dest_party] = stub + timeout = self._proxy_config.timeout_in_seconds response = await send_data_grpc( data=data, stub=self._stubs[dest_party], upstream_seq_id=upstream_seq_id, downstream_seq_id=downstream_seq_id, + timeout=timeout, metadata=dest_party_grpc_config['grpc_metadata'], ) return response @@ -83,7 +88,7 @@ def setup_grpc_config(self, dest_party): dest_party_grpc_config['grpc_metadata'] = { **global_grpc_metadata, **dest_party_grpc_metadata} - global_grpc_options = dict(get_grpc_options(self.retry_policy)) + global_grpc_options = dict(get_grpc_options(self._retry_policy)) dest_party_grpc_options = dict( self._cluster[dest_party].get('grpc_options', {}) ) @@ -101,9 +106,10 @@ async def send_data_grpc( stub, upstream_seq_id, downstream_seq_id, + timeout, metadata=None, ): - cluster_config = fed_config.get_cluster_config() + job_config = fed_config.get_job_config() data = cloudpickle.dumps(data) request = fed_pb2.SendDataRequest( data=data, @@ -114,7 +120,7 @@ async def send_data_grpc( response = await stub.SendData( request, metadata=fed_utils.dict2tuple(metadata), - timeout=cluster_config.cross_silo_timeout, + timeout=timeout, ) logger.debug( f'Received data response from seq_id {downstream_seq_id}, ' @@ -124,8 +130,9 @@ async def send_data_grpc( class GrpcRecvProxy(RecvProxy): - def __init__(self, listen_addr: str, party: str, proxy_config: CrossSiloCommConfig) -> None: - super().__init__(listen_addr, party, proxy_config) + def __init__(self, listen_addr: str, party: str, tls_config: Dict, proxy_config: CrossSiloCommConfig) -> None: + super().__init__(listen_addr, party, tls_config, proxy_config) + set_max_message_length(proxy_config.messages_max_size_in_bytes) # Flag to see whether grpc server starts self._server_ready_future = asyncio.Future() self._retry_policy = None @@ -159,7 +166,8 @@ async def start(self): async def is_ready(self): await self._server_ready_future - return self._server_ready_future.result() + res = self._server_ready_future.result() + return res async def get_data(self, src_party, upstream_seq_id, curr_seq_id): data_log_msg = f"data for {curr_seq_id} from {upstream_seq_id} of {src_party}" diff --git a/fed/tmp.py b/fed/tmp.py new file mode 100644 index 00000000..24455aaa --- /dev/null +++ b/fed/tmp.py @@ -0,0 +1,10 @@ +import ray +import fed +ray.init() + +cluster = { + 'alice': {'address': '127.0.0.1:11012'}, + 'bob': {'address': '127.0.0.1:11011'}, +} +party = 'alice' +fed.init(cluster, party) \ No newline at end of file diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 9fa5f101..c2a45d8b 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -20,7 +20,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloCommConfig +from fed.config import CrossSiloGrpcCommConfig @fed.remote @@ -53,7 +53,7 @@ def run(party, is_inner_party): fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloCommConfig( + cross_silo_comm_config=CrossSiloGrpcCommConfig( grpc_retry_policy=retry_policy ) ) From 582f5239f6fbe19681d267484a5f0c2b16dbfafe Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 11 Jul 2023 01:20:34 +0800 Subject: [PATCH 10/22] lint codes --- fed/api.py | 7 +++--- fed/config.py | 6 ----- fed/proxy/barriers.py | 33 ++++++++++++------------- fed/proxy/grpc_proxy.py | 54 ++++++++++++++++++++++------------------- fed/tmp.py | 10 -------- 5 files changed, 49 insertions(+), 61 deletions(-) delete mode 100644 fed/tmp.py diff --git a/fed/api.py b/fed/api.py index 7375d16d..39ad71fb 100644 --- a/fed/api.py +++ b/fed/api.py @@ -153,7 +153,8 @@ def init( 'cert' in tls_config and 'key' in tls_config ), 'Cert or key are not in tls_config.' - global_cross_silo_comm_config = global_cross_silo_comm_config or CrossSiloCommConfig() + global_cross_silo_comm_config = \ + global_cross_silo_comm_config or CrossSiloCommConfig() # A Ray private accessing, should be replaced in public API. compatible_utils._init_internal_kv() @@ -184,7 +185,7 @@ def init( logger.info(f'Started rayfed with {cluster_config}') get_global_context().get_cleanup_manager().start( exit_when_failure_sending=global_cross_silo_comm_config.exit_on_sending_failure) - + if recv_proxy_cls is None: from fed.proxy.grpc_proxy import GrpcRecvProxy recv_proxy_cls = GrpcRecvProxy @@ -207,7 +208,7 @@ def init( logging_level=logging_level, tls_config=tls_config, proxy_cls=send_proxy_cls, - proxy_config=global_cross_silo_comm_config # retry_policy=cross_silo_comm_config.grpc_retry_policy, + proxy_config=global_cross_silo_comm_config ) if enable_waiting_for_other_parties_ready: diff --git a/fed/config.py b/fed/config.py index d43be61e..f171ce80 100644 --- a/fed/config.py +++ b/fed/config.py @@ -176,12 +176,6 @@ class CrossSiloBRPCConfig(CrossSiloCommConfig): """A class to store parameters used for GRPC communication Attributes: - grpc_retry_policy: - grpc_channel_options: A list of tuples to store GRPC channel options, - e.g. [ - ('grpc.enable_retries', 1), - ('grpc.max_send_message_length', 50 * 1024 * 1024) - ] """ def __init__(self, brpc_options, diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index a7279390..90c7c966 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -13,14 +13,11 @@ # limitations under the License. import abc -import asyncio import logging -import threading import time import copy from typing import Dict, Optional -import cloudpickle import ray import fed.config as fed_config @@ -64,13 +61,13 @@ def __init__( cluster: Dict, party: str, tls_config: Dict, - proxy_config = None) -> None: + proxy_config=None + ) -> None: self._cluster = cluster self._party = party self._tls_config = tls_config self._proxy_config = proxy_config - @abc.abstractmethod async def send( self, @@ -84,19 +81,20 @@ async def send( async def is_ready(self): return True + class RecvProxy(abc.ABC): def __init__( self, listen_addr: str, party: str, tls_config: Dict, - proxy_config: CrossSiloCommConfig) -> None: + proxy_config: CrossSiloCommConfig + ) -> None: self._listen_addr = listen_addr self._party = party self._tls_config = tls_config self._proxy_config = proxy_config - @abc.abstractmethod def start(self): pass @@ -121,7 +119,7 @@ def __init__( party: str, tls_config: Dict = None, logging_level: str = None, - proxy_cls = None + proxy_cls=None ): setup_logger( logging_level=logging_level, @@ -135,7 +133,8 @@ def __init__( self._party = party self._tls_config = tls_config cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config - self.proxy_instance: SendProxy = proxy_cls(cluster, party, tls_config, cross_silo_comm_config) + self.proxy_instance: SendProxy = proxy_cls( + cluster, party, tls_config, cross_silo_comm_config) async def is_ready(self): res = await self.proxy_instance.is_ready() @@ -161,18 +160,17 @@ async def send( ' credentials.' ) try: - response = await self.proxy_instance.send(dest_party, data, upstream_seq_id, downstream_seq_id) + response = await self.proxy_instance.send( + dest_party, data, upstream_seq_id, downstream_seq_id) except Exception as e: logger.error(f'Failed to {send_log_msg}, error: {e}') return False logger.debug(f"Succeeded to send {send_log_msg}. Response is {response}") return True # True indicates it's sent successfully. - async def _get_stats(self): return self._stats - async def _get_cluster_info(self): return self._cluster @@ -185,7 +183,7 @@ def __init__( party: str, logging_level: str, tls_config=None, - proxy_cls = None, + proxy_cls=None, ): setup_logger( logging_level=logging_level, @@ -198,7 +196,8 @@ def __init__( self._party = party self._tls_config = tls_config cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config - self._proxy_instance: RecvProxy = proxy_cls(listen_addr, party, tls_config, cross_silo_comm_config) + self._proxy_instance: RecvProxy = proxy_cls( + listen_addr, party, tls_config, cross_silo_comm_config) async def start(self): await self._proxy_instance.start() @@ -208,15 +207,15 @@ async def is_ready(self): return res async def get_data(self, src_party, upstream_seq_id, curr_seq_id): - self._stats["receive_op_count"] += 1 - data = await self._proxy_instance.get_data(src_party, upstream_seq_id, curr_seq_id) + self._stats["receive_op_count"] += 1 + data = await self._proxy_instance.get_data( + src_party, upstream_seq_id, curr_seq_id) return data async def _get_stats(self): return self._stats - _DEFAULT_RECV_PROXY_OPTIONS = { "max_concurrency": 1000, } diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc_proxy.py index 3cc88c01..be7530a8 100644 --- a/fed/proxy/grpc_proxy.py +++ b/fed/proxy/grpc_proxy.py @@ -6,7 +6,6 @@ from typing import Dict -import fed.config as fed_config import fed.utils as fed_utils from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig @@ -25,9 +24,14 @@ logger = logging.getLogger(__name__) - class GrpcSendProxy(SendProxy): - def __init__(self, cluster: Dict, party: str, tls_config: Dict, proxy_config=None) -> None: + def __init__( + self, + cluster: Dict, + party: str, + tls_config: Dict, + proxy_config=None + ) -> None: super().__init__(cluster, party, tls_config, proxy_config) self._grpc_metadata = proxy_config.http_header set_max_message_length(proxy_config.messages_max_size_in_bytes) @@ -50,20 +54,20 @@ async def send( grpc_options = get_grpc_options(retry_policy=self._retry_policy) if \ grpc_options is None else fed_utils.dict2tuple(grpc_options) if dest_party not in self._stubs: - if tls_enabled: - ca_cert, private_key, cert_chain = fed_utils.load_cert_config( - self._tls_config) - credentials = grpc.ssl_channel_credentials( - certificate_chain=cert_chain, - private_key=private_key, - root_certificates=ca_cert, - ) - channel = grpc.aio.secure_channel( - dest_addr, credentials, options=grpc_options) - else: - channel = grpc.aio.insecure_channel(dest_addr, options=grpc_options) - stub = fed_pb2_grpc.GrpcServiceStub(channel) - self._stubs[dest_party] = stub + if tls_enabled: + ca_cert, private_key, cert_chain = fed_utils.load_cert_config( + self._tls_config) + credentials = grpc.ssl_channel_credentials( + certificate_chain=cert_chain, + private_key=private_key, + root_certificates=ca_cert, + ) + channel = grpc.aio.secure_channel( + dest_addr, credentials, options=grpc_options) + else: + channel = grpc.aio.insecure_channel(dest_addr, options=grpc_options) + stub = fed_pb2_grpc.GrpcServiceStub(channel) + self._stubs[dest_party] = stub timeout = self._proxy_config.timeout_in_seconds response = await send_data_grpc( @@ -75,7 +79,7 @@ async def send( metadata=dest_party_grpc_config['grpc_metadata'], ) return response - + def setup_grpc_config(self, dest_party): dest_party_grpc_config = {} global_grpc_metadata = ( @@ -100,7 +104,6 @@ async def _get_grpc_options(self): return get_grpc_options() - async def send_data_grpc( data, stub, @@ -109,7 +112,6 @@ async def send_data_grpc( timeout, metadata=None, ): - job_config = fed_config.get_job_config() data = cloudpickle.dumps(data) request = fed_pb2.SendDataRequest( data=data, @@ -130,7 +132,13 @@ async def send_data_grpc( class GrpcRecvProxy(RecvProxy): - def __init__(self, listen_addr: str, party: str, tls_config: Dict, proxy_config: CrossSiloCommConfig) -> None: + def __init__( + self, + listen_addr: str, + party: str, + tls_config: Dict, + proxy_config: CrossSiloCommConfig + ) -> None: super().__init__(listen_addr, party, tls_config, proxy_config) set_max_message_length(proxy_config.messages_max_size_in_bytes) # Flag to see whether grpc server starts @@ -163,7 +171,6 @@ async def start(self): f' when calling `fed.init`. Grpc error msg: {err}' self._server_ready_future.set_result((False, msg)) - async def is_ready(self): await self._server_ready_future res = self._server_ready_future.result() @@ -195,8 +202,6 @@ async def _get_grpc_options(self): return get_grpc_options() - - class SendDataService(fed_pb2_grpc.GrpcServiceServicer): def __init__(self, all_events, all_data, party, lock): self._events = all_events @@ -258,4 +263,3 @@ async def _run_grpc_server( ) server_ready_future.set_result((True, msg)) await server.wait_for_termination() - diff --git a/fed/tmp.py b/fed/tmp.py deleted file mode 100644 index 24455aaa..00000000 --- a/fed/tmp.py +++ /dev/null @@ -1,10 +0,0 @@ -import ray -import fed -ray.init() - -cluster = { - 'alice': {'address': '127.0.0.1:11012'}, - 'bob': {'address': '127.0.0.1:11011'}, -} -party = 'alice' -fed.init(cluster, party) \ No newline at end of file From cdc199683f5f03ee51dc502b4de6a1f1d8a76b66 Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 11 Jul 2023 01:52:25 +0800 Subject: [PATCH 11/22] default val for Grpc config --- tests/test_exit_on_failure_sending.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 4b665d4e..a433df14 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -19,7 +19,7 @@ import fed import fed._private.compatible_utils as compatible_utils -from fed.config import CrossSiloCommConfig +from fed.config import CrossSiloGrpcCommConfig import signal @@ -63,7 +63,7 @@ def run(party, is_inner_party): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } - cross_silo_comm_config = CrossSiloCommConfig( + cross_silo_comm_config = CrossSiloGrpcCommConfig( grpc_retry_policy=retry_policy, exit_on_sending_failure=True ) @@ -90,4 +90,5 @@ def test_exit_when_failure_on_sending(): if __name__ == "__main__": - sys.exit(pytest.main(["-sv", __file__])) + # sys.exit(pytest.main(["-sv", __file__])) + test_exit_when_failure_on_sending() From b5a3f08bcfedc6a1dfb9f9d1eeba4bebcdd3ae93 Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 11 Jul 2023 01:52:42 +0800 Subject: [PATCH 12/22] default val for Grpc config --- fed/config.py | 6 +++--- fed/proxy/barriers.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fed/config.py b/fed/config.py index f171ce80..5689c03c 100644 --- a/fed/config.py +++ b/fed/config.py @@ -7,7 +7,7 @@ import fed._private.compatible_utils as compatible_utils import fed._private.constants as fed_constants import cloudpickle -from typing import Dict, Optional +from typing import Dict, List, Optional import json @@ -163,8 +163,8 @@ class CrossSiloGrpcCommConfig(CrossSiloCommConfig): ] """ def __init__(self, - grpc_channel_options, - grpc_retry_policy, + grpc_channel_options: List = None, + grpc_retry_policy: Dict[str, str] = None, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 90c7c966..5f4063a8 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -254,7 +254,7 @@ def start_recv_proxy( proxy_cls=proxy_cls ) recver_proxy_actor.start.remote() - timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds + timeout = proxy_config.timeout_in_seconds if proxy_config is not None else 60 server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout) assert server_state[0], server_state[1] logger.info("RecverProxy has successfully created.") From a432d31d7e926aab63147d14e625fafaed1801b2 Mon Sep 17 00:00:00 2001 From: paer Date: Wed, 12 Jul 2023 14:53:19 +0800 Subject: [PATCH 13/22] union grpc_options --- fed/_private/grpc_options.py | 56 ++++++------- fed/api.py | 7 +- fed/proxy/barriers.py | 19 ++++- fed/proxy/grpc_proxy.py | 114 ++++++++++++++++---------- tests/test_exit_on_failure_sending.py | 3 +- tests/test_grpc_options_on_proxies.py | 4 +- 6 files changed, 119 insertions(+), 84 deletions(-) diff --git a/fed/_private/grpc_options.py b/fed/_private/grpc_options.py index 9b7d30fd..6e4b2d14 100644 --- a/fed/_private/grpc_options.py +++ b/fed/_private/grpc_options.py @@ -14,7 +14,10 @@ import json -_GRPC_RETRY_POLICY = { + +_GRPC_SERVICE = "GrpcService" + +_DEFAULT_GRPC_RETRY_POLICY = { "maxAttempts": 5, "initialBackoff": "5s", "maxBackoff": "30s", @@ -22,49 +25,38 @@ "retryableStatusCodes": ["UNAVAILABLE"], } -_GRPC_SERVICE = "GrpcService" _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH = 500 * 1024 * 1024 _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH = 500 * 1024 * 1024 -_GRPC_MAX_SEND_MESSAGE_LENGTH = _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH -_GRPC_MAX_RECEIVE_MESSAGE_LENGTH = _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH - - -def set_max_message_length(max_size_in_bytes): - """Set the maximum length in bytes of gRPC messages. - - NOTE: The default maximum length is 500MB(500 * 1024 * 1024) - """ - global _GRPC_MAX_SEND_MESSAGE_LENGTH - global _GRPC_MAX_RECEIVE_MESSAGE_LENGTH - if not max_size_in_bytes: - return - if max_size_in_bytes < 0: - raise ValueError("Negative max size is not allowed") - _GRPC_MAX_SEND_MESSAGE_LENGTH = max_size_in_bytes - _GRPC_MAX_RECEIVE_MESSAGE_LENGTH = max_size_in_bytes - - -def get_grpc_max_send_message_length(): - global _GRPC_MAX_SEND_MESSAGE_LENGTH - return _GRPC_MAX_SEND_MESSAGE_LENGTH - - -def get_grpc_max_recieve_message_length(): - global _GRPC_MAX_SEND_MESSAGE_LENGTH - return _GRPC_MAX_SEND_MESSAGE_LENGTH +_DEFAULT_GRPC_CHANNEL_OPTIONS = { + 'grpc.enable_retries': 1, + 'grpc.so_reuseport': 0, + 'grpc.max_send_message_length': _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH, + 'grpc.max_receive_message_length': _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH, + 'grpc.service_config': + json.dumps( + { + 'methodConfig': [ + { + 'name': [{'service': _GRPC_SERVICE}], + 'retryPolicy': _DEFAULT_GRPC_RETRY_POLICY, + } + ] + } + ), +} def get_grpc_options( retry_policy=None, max_send_message_length=None, max_receive_message_length=None ): if not retry_policy: - retry_policy = _GRPC_RETRY_POLICY + retry_policy = _DEFAULT_GRPC_RETRY_POLICY if not max_send_message_length: - max_send_message_length = get_grpc_max_send_message_length() + max_send_message_length = _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH if not max_receive_message_length: - max_receive_message_length = get_grpc_max_recieve_message_length() + max_receive_message_length = _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH return [ ( diff --git a/fed/api.py b/fed/api.py index 39ad71fb..4a62aafb 100644 --- a/fed/api.py +++ b/fed/api.py @@ -69,7 +69,8 @@ def init( # (Optional) the listen address, the `address` will be # used if not provided. 'listen_addr': '0.0.0.0:10001', - # (Optional) The party specific metadata sent with the grpc request + 'cross_silo_comm_config': CrossSiloCommConfig + # (Optional) The party specific metadata sent with grpc requests 'grpc_metadata': (('token', 'alice-token'),), 'grpc_options': [ ('grpc.default_authority', 'alice'), @@ -82,7 +83,7 @@ def init( # (Optional) the listen address, the `address` will be # used if not provided. 'listen_addr': '0.0.0.0:10002', - # (Optional) The party specific metadata sent with the grpc request + # (Optional) The party specific metadata sent with grpc requests 'grpc_metadata': (('token', 'bob-token'),), }, 'carol': { @@ -91,7 +92,7 @@ def init( # (Optional) the listen address, the `address` will be # used if not provided. 'listen_addr': '0.0.0.0:10003', - # (Optional) The party specific metadata sent with the grpc request + # (Optional) The party specific metadata sent with grpc requests 'grpc_metadata': (('token', 'carol-token'),), }, } diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 5f4063a8..b711695f 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -61,7 +61,7 @@ def __init__( cluster: Dict, party: str, tls_config: Dict, - proxy_config=None + proxy_config: CrossSiloCommConfig=None ) -> None: self._cluster = cluster self._party = party @@ -81,6 +81,9 @@ async def send( async def is_ready(self): return True + async def get_proxy_config(self): + return self._proxy_config + class RecvProxy(abc.ABC): def __init__( @@ -110,6 +113,9 @@ async def get_data( async def is_ready(self): return True + async def get_proxy_config(self): + return self._proxy_config + @ray.remote class SendProxyActor: @@ -133,11 +139,11 @@ def __init__( self._party = party self._tls_config = tls_config cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config - self.proxy_instance: SendProxy = proxy_cls( + self._proxy_instance: SendProxy = proxy_cls( cluster, party, tls_config, cross_silo_comm_config) async def is_ready(self): - res = await self.proxy_instance.is_ready() + res = await self._proxy_instance.is_ready() return res async def send( @@ -160,7 +166,7 @@ async def send( ' credentials.' ) try: - response = await self.proxy_instance.send( + response = await self._proxy_instance.send( dest_party, data, upstream_seq_id, downstream_seq_id) except Exception as e: logger.error(f'Failed to {send_log_msg}, error: {e}') @@ -174,6 +180,8 @@ async def _get_stats(self): async def _get_cluster_info(self): return self._cluster + async def _get_proxy_config(self): + return await self._proxy_instance.get_proxy_config() @ray.remote class RecverProxyActor: @@ -215,6 +223,9 @@ async def get_data(self, src_party, upstream_seq_id, curr_seq_id): async def _get_stats(self): return self._stats + async def _get_proxy_config(self): + return await self._proxy_instance.get_proxy_config() + _DEFAULT_RECV_PROXY_OPTIONS = { "max_concurrency": 1000, diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc_proxy.py index be7530a8..03a20fd7 100644 --- a/fed/proxy/grpc_proxy.py +++ b/fed/proxy/grpc_proxy.py @@ -1,15 +1,17 @@ import asyncio +import copy import cloudpickle import grpc import logging import threading +import json from typing import Dict import fed.utils as fed_utils from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig -from fed._private.grpc_options import get_grpc_options, set_max_message_length +from fed._private.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE from fed.proxy.barriers import ( add_two_dim_dict, get_from_two_dim_dict, @@ -24,20 +26,47 @@ logger = logging.getLogger(__name__) +def parse_grpc_options(proxy_config: CrossSiloCommConfig): + grpc_channel_options = {} + if proxy_config is not None: + if proxy_config.messages_max_size_in_bytes is not None: + grpc_channel_options.update({ + 'grpc.max_send_message_length': + proxy_config.messages_max_size_in_bytes, + 'grpc.max_receive_message_length': + proxy_config.messages_max_size_in_bytes + }) + if isinstance(proxy_config, CrossSiloGrpcCommConfig): + grpc_channel_options.update(proxy_config.grpc_channel_options) + if proxy_config.grpc_retry_policy is not None: + grpc_channel_options.update({ + json.dumps( + { + 'methodConfig': [ + { + 'name': [{'service': _GRPC_SERVICE}], + 'retryPolicy': proxy_config.grpc_retry_policy, + } + ] + } + ), + }) + + return grpc_channel_options + + class GrpcSendProxy(SendProxy): def __init__( self, cluster: Dict, party: str, tls_config: Dict, - proxy_config=None + proxy_config: CrossSiloCommConfig=None ) -> None: super().__init__(cluster, party, tls_config, proxy_config) - self._grpc_metadata = proxy_config.http_header - set_max_message_length(proxy_config.messages_max_size_in_bytes) - self._retry_policy = None - if isinstance(proxy_config, CrossSiloGrpcCommConfig): - self._retry_policy = proxy_config.grpc_retry_policy + self._grpc_metadata = proxy_config.http_header or {} + self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) + self._grpc_options.update(parse_grpc_options(self._proxy_config)) # Mapping the destination party name to the reused client stub. self._stubs = {} @@ -48,11 +77,8 @@ async def send( upstream_seq_id, downstream_seq_id): dest_addr = self._cluster[dest_party]['address'] - dest_party_grpc_config = self.setup_grpc_config(dest_party) + grpc_metadata, grpc_channel_options = self.get_grpc_config_by_party(dest_party) tls_enabled = fed_utils.tls_enabled(self._tls_config) - grpc_options = dest_party_grpc_config['grpc_options'] - grpc_options = get_grpc_options(retry_policy=self._retry_policy) if \ - grpc_options is None else fed_utils.dict2tuple(grpc_options) if dest_party not in self._stubs: if tls_enabled: ca_cert, private_key, cert_chain = fed_utils.load_cert_config( @@ -63,9 +89,9 @@ async def send( root_certificates=ca_cert, ) channel = grpc.aio.secure_channel( - dest_addr, credentials, options=grpc_options) + dest_addr, credentials, options=grpc_channel_options) else: - channel = grpc.aio.insecure_channel(dest_addr, options=grpc_options) + channel = grpc.aio.insecure_channel(dest_addr, options=grpc_channel_options) stub = fed_pb2_grpc.GrpcServiceStub(channel) self._stubs[dest_party] = stub @@ -76,33 +102,36 @@ async def send( upstream_seq_id=upstream_seq_id, downstream_seq_id=downstream_seq_id, timeout=timeout, - metadata=dest_party_grpc_config['grpc_metadata'], + metadata=grpc_metadata, ) return response - def setup_grpc_config(self, dest_party): - dest_party_grpc_config = {} - global_grpc_metadata = ( - dict(self._grpc_metadata) if self._grpc_metadata is not None else {} - ) - dest_party_grpc_metadata = dict( - self._cluster[dest_party].get('grpc_metadata', {}) - ) - # merge grpc metadata - dest_party_grpc_config['grpc_metadata'] = { - **global_grpc_metadata, **dest_party_grpc_metadata} - - global_grpc_options = dict(get_grpc_options(self._retry_policy)) - dest_party_grpc_options = dict( - self._cluster[dest_party].get('grpc_options', {}) - ) - dest_party_grpc_config['grpc_options'] = { - **global_grpc_options, **dest_party_grpc_options} - return dest_party_grpc_config + def get_grpc_config_by_party(self, dest_party): + """Overide global config by party specific config + """ + grpc_metadata = self._grpc_metadata + grpc_options = self._grpc_options - async def _get_grpc_options(self): - return get_grpc_options() + dest_party_comm_config = self._cluster[dest_party].get( + 'cross_silo_comm_config', None) + if dest_party_comm_config is not None: + if dest_party_comm_config.http_header is not None: + dest_party_grpc_metadata = dict(dest_party_comm_config.http_header) + grpc_metadata = { + **grpc_metadata, + **dest_party_grpc_metadata + } + dest_party_grpc_options = parse_grpc_options(dest_party_comm_config) + grpc_options = fed_utils.dict2tuple({ + **grpc_options, **dest_party_grpc_options + }) + return grpc_metadata, grpc_options + async def get_proxy_config(self): + proxy_config = self._proxy_config.__dict__ + proxy_config.update({'grpc_options': fed_utils.dict2tuple(self._grpc_options)}) + return proxy_config + async def send_data_grpc( data, @@ -140,12 +169,11 @@ def __init__( proxy_config: CrossSiloCommConfig ) -> None: super().__init__(listen_addr, party, tls_config, proxy_config) - set_max_message_length(proxy_config.messages_max_size_in_bytes) + self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) + self._grpc_options.update(parse_grpc_options(self._proxy_config)) + # Flag to see whether grpc server starts self._server_ready_future = asyncio.Future() - self._retry_policy = None - if isinstance(proxy_config, CrossSiloGrpcCommConfig): - self._retry_policy = proxy_config.grpc_retry_policy # All events for grpc waitting usage. self._events = {} # map from (upstream_seq_id, downstream_seq_id) to event @@ -163,7 +191,7 @@ async def start(self): self._lock, self._server_ready_future, self._tls_config, - get_grpc_options(self._retry_policy), + self._grpc_options, ) except RuntimeError as err: msg = f'Grpc server failed to listen to port: {port}' \ @@ -198,8 +226,10 @@ async def get_data(self, src_party, upstream_seq_id, curr_seq_id): fed_ser_utils._apply_loads_function_with_whitelist() return cloudpickle.loads(data) - async def _get_grpc_options(self): - return get_grpc_options() + async def get_proxy_config(self): + proxy_config = self._proxy_config.__dict__ + proxy_config.update({'grpc_options': fed_utils.dict2tuple(self._grpc_options)}) + return proxy_config class SendDataService(fed_pb2_grpc.GrpcServiceServicer): diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index a433df14..3f6ed529 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -90,5 +90,4 @@ def test_exit_when_failure_on_sending(): if __name__ == "__main__": - # sys.exit(pytest.main(["-sv", __file__])) - test_exit_when_failure_on_sending() + sys.exit(pytest.main(["-sv", __file__])) diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index b3f1b057..cf2ebb3e 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -40,7 +40,9 @@ def run(party): ) def _assert_on_proxy(proxy_actor): - options = ray.get(proxy_actor._get_grpc_options.remote()) + config = ray.get(proxy_actor._get_proxy_config.remote()) + print(f"==============={config}==============") + options = config['grpc_options'] assert options[0][0] == "grpc.max_send_message_length" assert options[0][1] == 100 assert ('grpc.so_reuseport', 0) in options From 7ffe3673ff2225099f0d14d1d919f9a90e6f0633 Mon Sep 17 00:00:00 2001 From: paer Date: Wed, 12 Jul 2023 17:45:21 +0800 Subject: [PATCH 14/22] fix retry_policy update & get party grpc_options Signed-off-by: paer --- fed/proxy/barriers.py | 6 +- fed/proxy/grpc_proxy.py | 24 +++-- .../test_unpickle_with_whitelist.py | 2 +- tests/test_exit_on_failure_sending.py | 2 +- tests/test_grpc_options_on_proxies.py | 6 +- tests/test_grpc_options_per_party.py | 96 +++++++++++++++---- tests/test_party_specific_grpc_options.py | 78 --------------- tests/test_retry_policy.py | 2 +- tests/test_setup_proxy_actor.py | 2 +- 9 files changed, 101 insertions(+), 117 deletions(-) delete mode 100644 tests/test_party_specific_grpc_options.py diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index b711695f..7e862770 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -81,7 +81,7 @@ async def send( async def is_ready(self): return True - async def get_proxy_config(self): + async def get_proxy_config(self, dest_party=None): return self._proxy_config @@ -180,8 +180,8 @@ async def _get_stats(self): async def _get_cluster_info(self): return self._cluster - async def _get_proxy_config(self): - return await self._proxy_instance.get_proxy_config() + async def _get_proxy_config(self, dest_party=None): + return await self._proxy_instance.get_proxy_config(dest_party) @ray.remote class RecverProxyActor: diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc_proxy.py index 03a20fd7..233c0338 100644 --- a/fed/proxy/grpc_proxy.py +++ b/fed/proxy/grpc_proxy.py @@ -37,9 +37,11 @@ def parse_grpc_options(proxy_config: CrossSiloCommConfig): proxy_config.messages_max_size_in_bytes }) if isinstance(proxy_config, CrossSiloGrpcCommConfig): - grpc_channel_options.update(proxy_config.grpc_channel_options) + if proxy_config.grpc_channel_options is not None: + grpc_channel_options.update(proxy_config.grpc_channel_options) if proxy_config.grpc_retry_policy is not None: grpc_channel_options.update({ + 'grpc.service_config': json.dumps( { 'methodConfig': [ @@ -122,16 +124,20 @@ def get_grpc_config_by_party(self, dest_party): **dest_party_grpc_metadata } dest_party_grpc_options = parse_grpc_options(dest_party_comm_config) - grpc_options = fed_utils.dict2tuple({ + grpc_options = { **grpc_options, **dest_party_grpc_options - }) - return grpc_metadata, grpc_options - - async def get_proxy_config(self): + } + return grpc_metadata, fed_utils.dict2tuple(grpc_options) + + async def get_proxy_config(self, dest_party=None): + if dest_party is None: + grpc_options = fed_utils.dict2tuple(self._grpc_options) + else: + _, grpc_options = self.get_grpc_config_by_party(dest_party) proxy_config = self._proxy_config.__dict__ - proxy_config.update({'grpc_options': fed_utils.dict2tuple(self._grpc_options)}) + proxy_config.update({'grpc_options': grpc_options}) return proxy_config - + async def send_data_grpc( data, @@ -191,7 +197,7 @@ async def start(self): self._lock, self._server_ready_future, self._tls_config, - self._grpc_options, + fed_utils.dict2tuple(self._grpc_options), ) except RuntimeError as err: msg = f'Grpc server failed to listen to port: {port}' \ diff --git a/tests/serializations_tests/test_unpickle_with_whitelist.py b/tests/serializations_tests/test_unpickle_with_whitelist.py index 28b4ca81..0cf75e66 100644 --- a/tests/serializations_tests/test_unpickle_with_whitelist.py +++ b/tests/serializations_tests/test_unpickle_with_whitelist.py @@ -53,7 +53,7 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloCommConfig( serializing_allowed_list=allowed_list )) diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 3f6ed529..555e7557 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -71,7 +71,7 @@ def run(party, is_inner_party): cluster=cluster, party=party, logging_level='debug', - cross_silo_comm_config=cross_silo_comm_config + global_cross_silo_comm_config=cross_silo_comm_config ) o = f.party("alice").remote() diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index cf2ebb3e..b70bf0e6 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -35,16 +35,14 @@ def run(party): fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloCommConfig( messages_max_size_in_bytes=100) ) def _assert_on_proxy(proxy_actor): config = ray.get(proxy_actor._get_proxy_config.remote()) - print(f"==============={config}==============") options = config['grpc_options'] - assert options[0][0] == "grpc.max_send_message_length" - assert options[0][1] == 100 + assert ("grpc.max_send_message_length", 100) in options assert ('grpc.so_reuseport', 0) in options send_proxy = ray.get_actor("SendProxyActor") diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py index e99d3d9a..53b70cbd 100644 --- a/tests/test_grpc_options_per_party.py +++ b/tests/test_grpc_options_per_party.py @@ -18,7 +18,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloCommConfig +from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig @fed.remote @@ -31,39 +31,34 @@ def run(party): cluster = { 'alice': { 'address': '127.0.0.1:11010', - 'grpc_options': [ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 200) - ] + 'cross_silo_comm_config': CrossSiloGrpcCommConfig( + grpc_channel_options=[ + ('grpc.default_authority', 'alice'), + ('grpc.max_send_message_length', 200) + ]) }, 'bob': {'address': '127.0.0.1:11011'}, } fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloCommConfig( messages_max_size_in_bytes=100) ) def _assert_on_send_proxy(proxy_actor): - alice_config = ray.get(proxy_actor.setup_grpc_config.remote('alice')) + alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) # print(f"【NKcqx】alice config: {alice_config}") assert 'grpc_options' in alice_config alice_options = alice_config['grpc_options'] - assert 'grpc.max_send_message_length' in alice_options - # This should be overwritten by cluster config - assert alice_options['grpc.max_send_message_length'] == 200 - assert 'grpc.default_authority' in alice_options - assert alice_options['grpc.default_authority'] == 'alice' - - bob_config = ray.get(proxy_actor.setup_grpc_config.remote('bob')) - # print(f"【NKcqx】bob config: {bob_config}") + assert ('grpc.max_send_message_length', 200) in alice_options + assert ('grpc.default_authority', 'alice') in alice_options + + bob_config = ray.get(proxy_actor._get_proxy_config.remote('bob')) assert 'grpc_options' in bob_config bob_options = bob_config['grpc_options'] - assert "grpc.max_send_message_length" in bob_options - # Not setting bob's grpc_options, should be the same with global - assert bob_options["grpc.max_send_message_length"] == 100 - assert 'grpc.default_authority' not in bob_options + assert ('grpc.max_send_message_length', 100) in bob_options + assert not any(o[0] == 'grpc.default_authority' for o in bob_options) send_proxy = ray.get_actor("SendProxyActor") _assert_on_send_proxy(send_proxy) @@ -86,6 +81,69 @@ def test_grpc_options(): assert p_alice.exitcode == 0 and p_bob.exitcode == 0 +def party_grpc_options(party): + compatible_utils.init_ray(address='local') + cluster = { + 'alice': { + 'address': '127.0.0.1:11010', + 'cross_silo_comm_config': CrossSiloGrpcCommConfig( + grpc_channel_options=[ + ('grpc.default_authority', 'alice'), + ('grpc.max_send_message_length', 51 * 1024 * 1024) + ]) + }, + 'bob': { + 'address': '127.0.0.1:11011', + 'cross_silo_comm_config': CrossSiloGrpcCommConfig( + grpc_channel_options=[ + ('grpc.default_authority', 'bob'), + ('grpc.max_send_message_length', 50 * 1024 * 1024) + ]) + }, + } + fed.init( + cluster=cluster, + party=party, + global_cross_silo_comm_config=CrossSiloCommConfig( + messages_max_size_in_bytes=100) + ) + + def _assert_on_send_proxy(proxy_actor): + alice_config = ray.get(proxy_actor._get_proxy_config.remote('alice')) + assert 'grpc_options' in alice_config + alice_options = alice_config['grpc_options'] + assert ('grpc.max_send_message_length', 51 * 1024 * 1024) in alice_options + assert ('grpc.default_authority', 'alice') in alice_options + + bob_config = ray.get(proxy_actor._get_proxy_config.remote('bob')) + assert 'grpc_options' in bob_config + bob_options = bob_config['grpc_options'] + assert ('grpc.max_send_message_length', 50 * 1024 * 1024) in bob_options + assert ('grpc.default_authority', 'bob') in bob_options + + send_proxy = ray.get_actor("SendProxyActor") + _assert_on_send_proxy(send_proxy) + + a = dummpy.party('alice').remote() + b = dummpy.party('bob').remote() + fed.get([a, b]) + + fed.shutdown() + ray.shutdown() + + +def test_party_specific_grpc_options(): + p_alice = multiprocessing.Process( + target=party_grpc_options, args=('alice',)) + p_bob = multiprocessing.Process( + target=party_grpc_options, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + if __name__ == "__main__": import sys diff --git a/tests/test_party_specific_grpc_options.py b/tests/test_party_specific_grpc_options.py deleted file mode 100644 index 1c017ca9..00000000 --- a/tests/test_party_specific_grpc_options.py +++ /dev/null @@ -1,78 +0,0 @@ -import multiprocessing -import pytest -import fed -import fed._private.compatible_utils as compatible_utils -import ray - -from fed.config import CrossSiloCommConfig - - -@fed.remote -def dummpy(): - return 2 - - -def party_grpc_options(party): - compatible_utils.init_ray(address='local') - cluster = { - 'alice': { - 'address': '127.0.0.1:11010', - 'grpc_channel_option': [ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 51 * 1024 * 1024) - ]}, - 'bob': { - 'address': '127.0.0.1:11011', - 'grpc_channel_option': [ - ('grpc.default_authority', 'bob'), - ('grpc.max_send_message_length', 50 * 1024 * 1024) - ]}, - } - fed.init( - cluster=cluster, - party=party, - cross_silo_comm_config=CrossSiloCommConfig( - messages_max_size_in_bytes=100) - ) - - def _assert_on_proxy(proxy_actor): - cluster_info = ray.get(proxy_actor._get_cluster_info.remote()) - assert cluster_info['alice'] is not None - assert cluster_info['alice']['grpc_channel_option'] is not None - alice_channel_options = cluster_info['alice']['grpc_channel_option'] - assert ('grpc.default_authority', 'alice') in alice_channel_options - assert ('grpc.max_send_message_length', 51 * 1024 * 1024) in alice_channel_options # noqa - - assert cluster_info['bob'] is not None - assert cluster_info['bob']['grpc_channel_option'] is not None - bob_channel_options = cluster_info['bob']['grpc_channel_option'] - assert ('grpc.default_authority', 'bob') in bob_channel_options - assert ('grpc.max_send_message_length', 50 * 1024 * 1024) in bob_channel_options # noqa - - send_proxy = ray.get_actor("SendProxyActor") - _assert_on_proxy(send_proxy) - - a = dummpy.party('alice').remote() - b = dummpy.party('bob').remote() - fed.get([a, b]) - - fed.shutdown() - ray.shutdown() - - -def test_party_specific_grpc_options(): - p_alice = multiprocessing.Process( - target=party_grpc_options, args=('alice',)) - p_bob = multiprocessing.Process( - target=party_grpc_options, args=('bob',)) - p_alice.start() - p_bob.start() - p_alice.join() - p_bob.join() - assert p_alice.exitcode == 0 and p_bob.exitcode == 0 - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main(["-sv", __file__])) diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index c2a45d8b..15f37fc0 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -53,7 +53,7 @@ def run(party, is_inner_party): fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloGrpcCommConfig( + global_cross_silo_comm_config=CrossSiloGrpcCommConfig( grpc_retry_policy=retry_policy ) ) diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index a8f53f68..c95fcefa 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -65,7 +65,7 @@ def run_failure(party): fed.init( cluster=cluster, party=party, - cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloCommConfig( send_resource_label=send_proxy_resources, recv_resource_label=recv_proxy_resources, timeout_in_seconds=10, From 2f9ce5756063b1c886fb62e73b9ca3397ba47d81 Mon Sep 17 00:00:00 2001 From: paer Date: Thu, 13 Jul 2023 11:10:00 +0800 Subject: [PATCH 15/22] fix mock proxy UT --- fed/config.py | 21 ++-------- fed/proxy/barriers.py | 6 +-- tests/test_transport_proxy.py | 66 ++++++++++++++++++++----------- tests/test_transport_proxy_tls.py | 7 ++++ 4 files changed, 58 insertions(+), 42 deletions(-) diff --git a/fed/config.py b/fed/config.py index 5689c03c..ce3f5e86 100644 --- a/fed/config.py +++ b/fed/config.py @@ -50,7 +50,7 @@ def __init__(self, raw_bytes: bytes) -> None: @property def cross_silo_comm_config(self): - return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG, {}) + return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG, CrossSiloCommConfig()) # A module level cache for the cluster configurations. @@ -85,7 +85,7 @@ class CrossSiloCommConfig: """A class to store parameters used for Proxy Actor Attributes: - proxier_fo_max_retries: The max restart times for the send proxy. + proxy_max_restarts: The max restart times for the send proxy. serializing_allowed_list: The package or class list allowed for serializing(deserializating) cross silos. It's used for avoiding pickle deserializing execution attack when crossing solis. @@ -111,7 +111,7 @@ class CrossSiloCommConfig: """ def __init__( self, - proxier_fo_max_retries: int = None, + proxy_max_restarts: int = None, timeout_in_seconds: int = 60, messages_max_size_in_bytes: int = None, exit_on_sending_failure: Optional[bool] = False, @@ -119,7 +119,7 @@ def __init__( send_resource_label: Optional[Dict[str, str]] = None, recv_resource_label: Optional[Dict[str, str]] = None, http_header: Optional[Dict[str, str]] = None) -> None: - self.proxier_fo_max_retries = proxier_fo_max_retries + self.proxy_max_restarts = proxy_max_restarts self.timeout_in_seconds = timeout_in_seconds self.messages_max_size_in_bytes = messages_max_size_in_bytes self.exit_on_sending_failure = exit_on_sending_failure @@ -170,16 +170,3 @@ def __init__(self, super().__init__(*args, **kwargs) self.grpc_retry_policy = grpc_retry_policy self.grpc_channel_options = grpc_channel_options - - -class CrossSiloBRPCConfig(CrossSiloCommConfig): - """A class to store parameters used for GRPC communication - - Attributes: - """ - def __init__(self, - brpc_options, - *args, - **kwargs): - super().__init__(*args, **kwargs) - self.brpc_options = brpc_options diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 7e862770..3e2c7bc4 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -91,7 +91,7 @@ def __init__( listen_addr: str, party: str, tls_config: Dict, - proxy_config: CrossSiloCommConfig + proxy_config: CrossSiloCommConfig=None ) -> None: self._listen_addr = listen_addr self._party = party @@ -289,9 +289,9 @@ def start_send_proxy( global _SEND_PROXY_ACTOR actor_options = copy.deepcopy(_DEFAULT_SEND_PROXY_OPTIONS) - if proxy_config and proxy_config.proxier_fo_max_retries: + if proxy_config and proxy_config.proxy_max_restarts: actor_options.update({ - "max_task_retries": proxy_config.proxier_fo_max_retries, + "max_task_retries": proxy_config.proxy_max_restarts, "max_restarts": 1, }) if proxy_config and proxy_config.send_resource_label: diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index ae5274cf..5dd583c1 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -20,10 +20,17 @@ import fed._private.compatible_utils as compatible_utils +from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig from fed._private import constants from fed._private import global_context from fed.grpc import fed_pb2, fed_pb2_grpc -from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy +from fed.proxy.barriers import ( + send, + start_recv_proxy, + start_send_proxy, + RecvProxy +) +from fed.proxy.grpc_proxy import GrpcSendProxy, GrpcRecvProxy def test_n_to_1_transport(): @@ -38,9 +45,6 @@ def test_n_to_1_transport(): constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {}, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60, } compatible_utils._init_internal_kv() compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, @@ -50,12 +54,21 @@ def test_n_to_1_transport(): SERVER_ADDRESS = "127.0.0.1:12344" party = 'test_party' cluster_config = {'test_party': {'address': SERVER_ADDRESS}} + config = CrossSiloGrpcCommConfig() start_recv_proxy( cluster_config, party, logging_level='info', + proxy_cls=GrpcRecvProxy, + proxy_config=config + ) + start_send_proxy( + cluster_config, + party, + logging_level='info', + proxy_cls=GrpcSendProxy, + proxy_config=config ) - start_send_proxy(cluster_config, party, logging_level='info') sent_objs = [] get_objs = [] @@ -147,7 +160,7 @@ def _test_start_recv_proxy( ).remote( listen_addr=listen_addr, party=party, - expected_metadata=expected_metadata, + expected_metadata=expected_metadata ) recver_proxy_actor.run_grpc_server.remote() assert ray.get(recver_proxy_actor.is_ready.remote()) @@ -159,14 +172,14 @@ def test_send_grpc_with_meta(): constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {}, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60, } + metadata = {"key": "value"} + send_proxy_config = CrossSiloCommConfig( + http_header=metadata + ) job_config = { - constants.KEY_OF_GRPC_METADATA: { - "key": "value" - } + constants.KEY_OF_CROSS_SILO_COMM_CONFIG: + send_proxy_config, } compatible_utils._init_internal_kv() compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, @@ -180,9 +193,14 @@ def test_send_grpc_with_meta(): cluster_config = {'test_party': {'address': SERVER_ADDRESS}} _test_start_recv_proxy( cluster_config, party, logging_level='info', - expected_metadata={"key": "value"}, + expected_metadata=metadata, ) - start_send_proxy(cluster_config, party, logging_level='info') + start_send_proxy( + cluster_config, + party, + logging_level='info', + proxy_cls=GrpcSendProxy, + proxy_config=CrossSiloGrpcCommConfig()) sent_objs = [] sent_obj = send(party, "data", 0, 1) sent_objs.append(sent_obj) @@ -200,14 +218,12 @@ def test_send_grpc_with_party_specific_meta(): constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {}, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60, } + send_proxy_config = CrossSiloCommConfig( + http_header={"key": "value"}) job_config = { - constants.KEY_OF_GRPC_METADATA: { - "key": "value" - } + constants.KEY_OF_CROSS_SILO_COMM_CONFIG: + send_proxy_config, } compatible_utils._init_internal_kv() compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, @@ -221,14 +237,20 @@ def test_send_grpc_with_party_specific_meta(): cluster_parties_config = { 'test_party': { 'address': SERVER_ADDRESS, - 'grpc_metadata': (('token', 'test-party-token'),) + 'cross_silo_comm_config': CrossSiloCommConfig( + http_header={"token": "test-party-token"}) } } _test_start_recv_proxy( cluster_parties_config, party, logging_level='info', expected_metadata={"key": "value", "token": "test-party-token"}, ) - start_send_proxy(cluster_parties_config, party, logging_level='info') + start_send_proxy( + cluster_parties_config, + party, + logging_level='info', + proxy_cls=GrpcSendProxy, + proxy_config=send_proxy_config) sent_objs = [] sent_obj = send(party, "data", 0, 1) sent_objs.append(sent_obj) diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index 65d24f9b..fc9a4fd1 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -22,6 +22,8 @@ from fed._private import constants from fed._private import global_context from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy +from fed.proxy.grpc_proxy import GrpcSendProxy, GrpcRecvProxy +from fed.config import CrossSiloGrpcCommConfig def test_n_to_1_transport(): @@ -58,17 +60,22 @@ def test_n_to_1_transport(): SERVER_ADDRESS = "127.0.0.1:65422" party = 'test_party' cluster_config = {'test_party': {'address': SERVER_ADDRESS}} + config = CrossSiloGrpcCommConfig() start_recv_proxy( cluster_config, party, logging_level='info', tls_config=tls_config, + proxy_cls=GrpcRecvProxy, + proxy_config=config ) start_send_proxy( cluster_config, party, logging_level='info', tls_config=tls_config, + proxy_cls=GrpcSendProxy, + proxy_config=config ) sent_objs = [] From dec615326528d611a7d0cef2a002494655a98c75 Mon Sep 17 00:00:00 2001 From: paer Date: Thu, 13 Jul 2023 19:45:43 +0800 Subject: [PATCH 16/22] per party config & timeout ms & listen IPv6 Signed-off-by: paer --- fed/_private/constants.py | 6 -- fed/api.py | 13 --- fed/config.py | 167 ++++++++++++++++++++++-------- fed/proxy/barriers.py | 4 +- fed/proxy/grpc_proxy.py | 4 +- tests/test_listen_addr.py | 56 +++++----- tests/test_setup_proxy_actor.py | 2 +- tests/test_transport_proxy_tls.py | 3 - 8 files changed, 164 insertions(+), 91 deletions(-) diff --git a/fed/_private/constants.py b/fed/_private/constants.py index 3b8e9ddd..96681e37 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -27,12 +27,6 @@ KEY_OF_CROSS_SILO_COMM_CONFIG = "CROSS_SILO_COMM_CONFIG" -KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST = "CROSS_SILO_SERIALIZING_ALLOWED_LIST" # noqa - -KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES = "CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES" # noqa - -KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS = "CROSS_SILO_TIMEOUT_IN_SECONDS" - RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S" diff --git a/fed/api.py b/fed/api.py index 4a62aafb..0ca28866 100644 --- a/fed/api.py +++ b/fed/api.py @@ -70,12 +70,6 @@ def init( # used if not provided. 'listen_addr': '0.0.0.0:10001', 'cross_silo_comm_config': CrossSiloCommConfig - # (Optional) The party specific metadata sent with grpc requests - 'grpc_metadata': (('token', 'alice-token'),), - 'grpc_options': [ - ('grpc.default_authority', 'alice'), - ('grpc.max_send_message_length', 50 * 1024 * 1024) - ] }, 'bob': { # The address for other parties. @@ -122,13 +116,6 @@ def init( global_cross_silo_comm_config: Global cross-silo communication related config that are applied to all connections. Supported configs can refer to CrossSiloCommConfig in config.py. - dest_party_comm_config: Communication config for the destination party - specifed by the key. E.g. - .. code:: python - { - 'alice': alice_CrossSiloCommConfig, - 'bob': bob_CrossSiloCommConfig - } Examples: >>> import fed diff --git a/fed/config.py b/fed/config.py index ce3f5e86..a13d922b 100644 --- a/fed/config.py +++ b/fed/config.py @@ -7,9 +7,11 @@ import fed._private.compatible_utils as compatible_utils import fed._private.constants as fed_constants import cloudpickle -from typing import Dict, List, Optional import json +from typing import Dict, List, Optional +from dataclasses import dataclass + class ClusterConfig: """A local cache of cluster configuration items.""" @@ -28,18 +30,6 @@ def current_party(self): def tls_config(self): return self._data[fed_constants.KEY_OF_TLS_CONFIG] - @property - def serializing_allowed_list(self): - return self._data[fed_constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST] - - @property - def cross_silo_timeout(self): - return self._data[fed_constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS] - - @property - def cross_silo_messages_max_size(self): - return self._data[fed_constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES] - class JobConfig: def __init__(self, raw_bytes: bytes) -> None: @@ -81,6 +71,99 @@ def get_job_config(): return _job_config +# class CrossSiloCommConfig: +# """A class to store parameters used for Proxy Actor + +# Attributes: +# proxy_max_restarts: The max restart times for the send proxy. +# serializing_allowed_list: The package or class list allowed for +# serializing(deserializating) cross silos. It's used for avoiding pickle +# deserializing execution attack when crossing solis. +# send_resource_label: Customized resource label, the SendProxyActor +# will be scheduled based on the declared resource label. For example, +# when setting to `{"my_label": 1}`, then the SendProxyActor will be started +# only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. +# recv_resource_label: Customized resource label, the RecverProxyActor +# will be scheduled based on the declared resource label. For example, +# when setting to `{"my_label": 1}`, then the RecverProxyActor will be started +# only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. +# exit_on_sending_failure: whether exit when failure on +# cross-silo sending. If True, a SIGTERM will be signaled to self +# if failed to sending cross-silo data. +# messages_max_size_in_bytes: The maximum length in bytes of +# cross-silo messages. +# If None, the default value of 500 MB is specified. +# timeout_in_seconds: The timeout in seconds of a cross-silo RPC call. +# It's 60 by default. +# http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request. +# This won't override basic tcp headers, such as `user-agent`, but concat +# them together. +# """ +# def __init__( +# self, +# proxy_max_restarts: int = None, +# timeout_in_seconds: int = 60, +# messages_max_size_in_bytes: int = None, +# exit_on_sending_failure: Optional[bool] = False, +# serializing_allowed_list: Optional[Dict[str, str]] = None, +# send_resource_label: Optional[Dict[str, str]] = None, +# recv_resource_label: Optional[Dict[str, str]] = None, +# http_header: Optional[Dict[str, str]] = None) -> None: +# self.proxy_max_restarts = proxy_max_restarts +# self.timeout_in_seconds = timeout_in_seconds +# self.messages_max_size_in_bytes = messages_max_size_in_bytes +# self.exit_on_sending_failure = exit_on_sending_failure +# self.serializing_allowed_list = serializing_allowed_list +# self.send_resource_label = send_resource_label +# self.recv_resource_label = recv_resource_label +# self.http_header = http_header + +# def __json__(self): +# return json.dumps(self.__dict__) + +# @classmethod +# def from_json(cls, json_str): +# data = json.loads(json_str) +# return cls(**data) + + +# class CrossSiloGrpcCommConfig(CrossSiloCommConfig): +# """A class to store parameters used for GRPC communication + +# Attributes: +# grpc_retry_policy: a dict descibes the retry policy for +# cross silo rpc call. If None, the following default retry policy +# will be used. More details please refer to +# `retry-policy `_. # noqa + +# .. code:: python +# { +# "maxAttempts": 4, +# "initialBackoff": "0.1s", +# "maxBackoff": "1s", +# "backoffMultiplier": 2, +# "retryableStatusCodes": [ +# "UNAVAILABLE" +# ] +# } +# grpc_channel_options: A list of tuples to store GRPC channel options, +# e.g. [ +# ('grpc.enable_retries', 1), +# ('grpc.max_send_message_length', 50 * 1024 * 1024) +# ] +# """ +# def __init__(self, +# grpc_channel_options: List = None, +# grpc_retry_policy: Dict[str, str] = None, +# *args, +# **kwargs): +# super().__init__(*args, **kwargs) +# self.grpc_retry_policy = grpc_retry_policy +# self.grpc_channel_options = grpc_channel_options + + + +@dataclass class CrossSiloCommConfig: """A class to store parameters used for Proxy Actor @@ -103,30 +186,20 @@ class CrossSiloCommConfig: messages_max_size_in_bytes: The maximum length in bytes of cross-silo messages. If None, the default value of 500 MB is specified. - timeout_in_seconds: The timeout in seconds of a cross-silo RPC call. - It's 60 by default. + timeout_in_ms: The timeout in mili-seconds of a cross-silo RPC call. + It's 60000 by default. http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request. This won't override basic tcp headers, such as `user-agent`, but concat them together. """ - def __init__( - self, - proxy_max_restarts: int = None, - timeout_in_seconds: int = 60, - messages_max_size_in_bytes: int = None, - exit_on_sending_failure: Optional[bool] = False, - serializing_allowed_list: Optional[Dict[str, str]] = None, - send_resource_label: Optional[Dict[str, str]] = None, - recv_resource_label: Optional[Dict[str, str]] = None, - http_header: Optional[Dict[str, str]] = None) -> None: - self.proxy_max_restarts = proxy_max_restarts - self.timeout_in_seconds = timeout_in_seconds - self.messages_max_size_in_bytes = messages_max_size_in_bytes - self.exit_on_sending_failure = exit_on_sending_failure - self.serializing_allowed_list = serializing_allowed_list - self.send_resource_label = send_resource_label - self.recv_resource_label = recv_resource_label - self.http_header = http_header + proxy_max_restarts: int = None + timeout_in_ms: int = 60000 + messages_max_size_in_bytes: int = None + exit_on_sending_failure: Optional[bool] = False + serializing_allowed_list: Optional[Dict[str, str]] = None + send_resource_label: Optional[Dict[str, str]] = None + recv_resource_label: Optional[Dict[str, str]] = None + http_header: Optional[Dict[str, str]] = None def __json__(self): return json.dumps(self.__dict__) @@ -136,7 +209,25 @@ def from_json(cls, json_str): data = json.loads(json_str) return cls(**data) + @classmethod + def from_dict(cls, data: Dict): + """Initialize CrossSiloCommConfig from a dictionary. + + Args: + data (Dict): Dictionary with keys as member variable names. + + Returns: + CrossSiloCommConfig: An instance of CrossSiloCommConfig. + """ + # Get the attributes of the class + attrs = {attr for attr, _ in cls.__annotations__.items()} + # Filter the dictionary to only include keys that are attributes of the class + filtered_data = {key: value for key, value in data.items() if key in attrs} + return cls(**filtered_data) + + +@dataclass class CrossSiloGrpcCommConfig(CrossSiloCommConfig): """A class to store parameters used for GRPC communication @@ -162,11 +253,5 @@ class CrossSiloGrpcCommConfig(CrossSiloCommConfig): ('grpc.max_send_message_length', 50 * 1024 * 1024) ] """ - def __init__(self, - grpc_channel_options: List = None, - grpc_retry_policy: Dict[str, str] = None, - *args, - **kwargs): - super().__init__(*args, **kwargs) - self.grpc_retry_policy = grpc_retry_policy - self.grpc_channel_options = grpc_channel_options + grpc_channel_options: List = None + grpc_retry_policy: Dict[str, str] = None diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 3e2c7bc4..b9efe9c6 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -265,7 +265,7 @@ def start_recv_proxy( proxy_cls=proxy_cls ) recver_proxy_actor.start.remote() - timeout = proxy_config.timeout_in_seconds if proxy_config is not None else 60 + timeout = proxy_config.timeout_in_ms / 1000 if proxy_config is not None else 60 server_state = ray.get(recver_proxy_actor.is_ready.remote(), timeout=timeout) assert server_state[0], server_state[1] logger.info("RecverProxy has successfully created.") @@ -308,7 +308,7 @@ def start_send_proxy( logging_level=logging_level, proxy_cls=proxy_cls ) - timeout = get_job_config().cross_silo_comm_config.timeout_in_seconds + timeout = get_job_config().cross_silo_comm_config.timeout_in_ms / 1000 assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout) logger.info("SendProxyActor has successfully created.") diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc_proxy.py index 233c0338..58ad8933 100644 --- a/fed/proxy/grpc_proxy.py +++ b/fed/proxy/grpc_proxy.py @@ -97,7 +97,7 @@ async def send( stub = fed_pb2_grpc.GrpcServiceStub(channel) self._stubs[dest_party] = stub - timeout = self._proxy_config.timeout_in_seconds + timeout = self._proxy_config.timeout_in_ms / 1000 response = await send_data_grpc( data=data, stub=self._stubs[dest_party], @@ -274,6 +274,7 @@ async def _run_grpc_server( port, event, all_data, party, lock, server_ready_future, tls_config=None, grpc_options=None ): + print(f"ReceiveProxy binding port {port}, options: {grpc_options}...") server = grpc.aio.server(options=grpc_options) fed_pb2_grpc.add_GrpcServiceServicer_to_server( SendDataService(event, all_data, party, lock), server @@ -290,6 +291,7 @@ async def _run_grpc_server( server.add_secure_port(f'[::]:{port}', server_credentials) else: server.add_insecure_port(f'[::]:{port}') + # server.add_insecure_port(f'[::]:{port}') msg = f"Succeeded to add port {port}." await server.start() diff --git a/tests/test_listen_addr.py b/tests/test_listen_addr.py index 72753e2a..05960e2f 100644 --- a/tests/test_listen_addr.py +++ b/tests/test_listen_addr.py @@ -72,29 +72,36 @@ def run(party): compatible_utils.init_ray(address='local') occupied_port = 11020 - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # NOTE(NKcqx): Firstly try to bind IPv6 because the grpc server will do so. + # Otherwise this UT will false because socket bind $occupied_port + # on IPv4 address while grpc server listendn Ipv6 address. + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) # Pre-occuping the port - s.bind(("localhost", occupied_port)) - - cluster = { - 'alice': { - 'address': '127.0.0.1:11012', - 'listen_addr': f'0.0.0.0:{occupied_port}'}, - 'bob': { - 'address': '127.0.0.1:11011', - 'listen_addr': '0.0.0.0:11011'}, - } - - # Starting grpc server on an used port will cause AssertionError - with pytest.raises(AssertionError): - fed.init(cluster=cluster, party=party) - - import time - - time.sleep(5) - s.close() - fed.shutdown() - ray.shutdown() + s.bind(("::", occupied_port)) + except OSError: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("127.0.0.1", occupied_port)) + + cluster = { + 'alice': { + 'address': '127.0.0.1:11012', + 'listen_addr': f'0.0.0.0:{occupied_port}'}, + 'bob': { + 'address': '127.0.0.1:11011', + 'listen_addr': '0.0.0.0:11011'}, + } + + # Starting grpc server on an used port will cause AssertionError + with pytest.raises(AssertionError): + fed.init(cluster=cluster, party=party) + + import time + + time.sleep(5) + s.close() + fed.shutdown() + ray.shutdown() p_alice = multiprocessing.Process(target=run, args=('alice',)) p_alice.start() @@ -103,6 +110,7 @@ def run(party): if __name__ == "__main__": - import sys + # import sys - sys.exit(pytest.main(["-sv", __file__])) + # sys.exit(pytest.main(["-sv", __file__])) + test_listen_used_addr() diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index c95fcefa..9d1ee40b 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -68,7 +68,7 @@ def run_failure(party): global_cross_silo_comm_config=CrossSiloCommConfig( send_resource_label=send_proxy_resources, recv_resource_label=recv_proxy_resources, - timeout_in_seconds=10, + timeout_in_ms=10*1000, ) ) diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index fc9a4fd1..6b7a7d3f 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -46,9 +46,6 @@ def test_n_to_1_transport(): constants.KEY_OF_CLUSTER_ADDRESSES: "", constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: tls_config, - constants.KEY_OF_CROSS_SILO_MESSAGES_MAX_SIZE_IN_BYTES: None, - constants.KEY_OF_CROSS_SILO_SERIALIZING_ALLOWED_LIST: {}, - constants.KEY_OF_CROSS_SILO_TIMEOUT_IN_SECONDS: 60, } global_context.get_global_context().get_cleanup_manager().start() From cd08616d69108aa4317e4b39b16fb77ae0e2a731 Mon Sep 17 00:00:00 2001 From: paer Date: Thu, 13 Jul 2023 20:03:33 +0800 Subject: [PATCH 17/22] lint codes --- fed/config.py | 97 ++--------------------------------- fed/proxy/barriers.py | 5 +- fed/proxy/grpc_proxy.py | 7 +-- tests/test_transport_proxy.py | 1 + 4 files changed, 11 insertions(+), 99 deletions(-) diff --git a/fed/config.py b/fed/config.py index a13d922b..21c495da 100644 --- a/fed/config.py +++ b/fed/config.py @@ -40,7 +40,9 @@ def __init__(self, raw_bytes: bytes) -> None: @property def cross_silo_comm_config(self): - return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG, CrossSiloCommConfig()) + return self._data.get( + fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG, + CrossSiloCommConfig()) # A module level cache for the cluster configurations. @@ -71,98 +73,6 @@ def get_job_config(): return _job_config -# class CrossSiloCommConfig: -# """A class to store parameters used for Proxy Actor - -# Attributes: -# proxy_max_restarts: The max restart times for the send proxy. -# serializing_allowed_list: The package or class list allowed for -# serializing(deserializating) cross silos. It's used for avoiding pickle -# deserializing execution attack when crossing solis. -# send_resource_label: Customized resource label, the SendProxyActor -# will be scheduled based on the declared resource label. For example, -# when setting to `{"my_label": 1}`, then the SendProxyActor will be started -# only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. -# recv_resource_label: Customized resource label, the RecverProxyActor -# will be scheduled based on the declared resource label. For example, -# when setting to `{"my_label": 1}`, then the RecverProxyActor will be started -# only on Nodes with `{"resource": {"my_label": $NUM}}` where $NUM >= 1. -# exit_on_sending_failure: whether exit when failure on -# cross-silo sending. If True, a SIGTERM will be signaled to self -# if failed to sending cross-silo data. -# messages_max_size_in_bytes: The maximum length in bytes of -# cross-silo messages. -# If None, the default value of 500 MB is specified. -# timeout_in_seconds: The timeout in seconds of a cross-silo RPC call. -# It's 60 by default. -# http_header: The HTTP header, e.g. metadata in grpc, sent with the RPC request. -# This won't override basic tcp headers, such as `user-agent`, but concat -# them together. -# """ -# def __init__( -# self, -# proxy_max_restarts: int = None, -# timeout_in_seconds: int = 60, -# messages_max_size_in_bytes: int = None, -# exit_on_sending_failure: Optional[bool] = False, -# serializing_allowed_list: Optional[Dict[str, str]] = None, -# send_resource_label: Optional[Dict[str, str]] = None, -# recv_resource_label: Optional[Dict[str, str]] = None, -# http_header: Optional[Dict[str, str]] = None) -> None: -# self.proxy_max_restarts = proxy_max_restarts -# self.timeout_in_seconds = timeout_in_seconds -# self.messages_max_size_in_bytes = messages_max_size_in_bytes -# self.exit_on_sending_failure = exit_on_sending_failure -# self.serializing_allowed_list = serializing_allowed_list -# self.send_resource_label = send_resource_label -# self.recv_resource_label = recv_resource_label -# self.http_header = http_header - -# def __json__(self): -# return json.dumps(self.__dict__) - -# @classmethod -# def from_json(cls, json_str): -# data = json.loads(json_str) -# return cls(**data) - - -# class CrossSiloGrpcCommConfig(CrossSiloCommConfig): -# """A class to store parameters used for GRPC communication - -# Attributes: -# grpc_retry_policy: a dict descibes the retry policy for -# cross silo rpc call. If None, the following default retry policy -# will be used. More details please refer to -# `retry-policy `_. # noqa - -# .. code:: python -# { -# "maxAttempts": 4, -# "initialBackoff": "0.1s", -# "maxBackoff": "1s", -# "backoffMultiplier": 2, -# "retryableStatusCodes": [ -# "UNAVAILABLE" -# ] -# } -# grpc_channel_options: A list of tuples to store GRPC channel options, -# e.g. [ -# ('grpc.enable_retries', 1), -# ('grpc.max_send_message_length', 50 * 1024 * 1024) -# ] -# """ -# def __init__(self, -# grpc_channel_options: List = None, -# grpc_retry_policy: Dict[str, str] = None, -# *args, -# **kwargs): -# super().__init__(*args, **kwargs) -# self.grpc_retry_policy = grpc_retry_policy -# self.grpc_channel_options = grpc_channel_options - - - @dataclass class CrossSiloCommConfig: """A class to store parameters used for Proxy Actor @@ -226,7 +136,6 @@ def from_dict(cls, data: Dict): return cls(**filtered_data) - @dataclass class CrossSiloGrpcCommConfig(CrossSiloCommConfig): """A class to store parameters used for GRPC communication diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index b28f0484..74c56a50 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -62,7 +62,7 @@ def __init__( cluster: Dict, party: str, tls_config: Dict, - proxy_config: CrossSiloCommConfig=None + proxy_config: CrossSiloCommConfig = None ) -> None: self._cluster = cluster self._party = party @@ -92,7 +92,7 @@ def __init__( listen_addr: str, party: str, tls_config: Dict, - proxy_config: CrossSiloCommConfig=None + proxy_config: CrossSiloCommConfig = None ) -> None: self._listen_addr = listen_addr self._party = party @@ -184,6 +184,7 @@ async def _get_cluster_info(self): async def _get_proxy_config(self, dest_party=None): return await self._proxy_instance.get_proxy_config(dest_party) + @ray.remote class RecverProxyActor: def __init__( diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc_proxy.py index 3ce0a002..3a5a2c79 100644 --- a/fed/proxy/grpc_proxy.py +++ b/fed/proxy/grpc_proxy.py @@ -38,7 +38,7 @@ def parse_grpc_options(proxy_config: CrossSiloCommConfig): if proxy_config is not None: if proxy_config.messages_max_size_in_bytes is not None: grpc_channel_options.update({ - 'grpc.max_send_message_length': + 'grpc.max_send_message_length': proxy_config.messages_max_size_in_bytes, 'grpc.max_receive_message_length': proxy_config.messages_max_size_in_bytes @@ -70,7 +70,7 @@ def __init__( cluster: Dict, party: str, tls_config: Dict, - proxy_config: CrossSiloCommConfig=None + proxy_config: CrossSiloCommConfig = None ) -> None: super().__init__(cluster, party, tls_config, proxy_config) self._grpc_metadata = proxy_config.http_header or {} @@ -100,7 +100,8 @@ async def send( channel = grpc.aio.secure_channel( dest_addr, credentials, options=grpc_channel_options) else: - channel = grpc.aio.insecure_channel(dest_addr, options=grpc_channel_options) + channel = grpc.aio.insecure_channel( + dest_addr, options=grpc_channel_options) stub = fed_pb2_grpc.GrpcServiceStub(channel) self._stubs[dest_party] = stub diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 3efa816c..1d7ff8e7 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -37,6 +37,7 @@ from fed.grpc import fed_pb2_in_protobuf3 as fed_pb2 from fed.grpc import fed_pb2_grpc_in_protobuf3 as fed_pb2_grpc + def test_n_to_1_transport(): """This case is used to test that we have N send_op barriers, sending data to the target recver proxy, and there also have From 2ab6483a839c0f09f3bb42835ee34a4333b0a38f Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 18 Jul 2023 10:37:47 +0800 Subject: [PATCH 18/22] update UT for removing messages_max_size_in_bytes --- fed/proxy/grpc_proxy.py | 16 +++++----------- tests/test_grpc_options_on_proxies.py | 9 ++++++--- tests/test_grpc_options_per_party.py | 14 ++++++++++---- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc_proxy.py index 3a5a2c79..a02bcc51 100644 --- a/fed/proxy/grpc_proxy.py +++ b/fed/proxy/grpc_proxy.py @@ -5,14 +5,13 @@ import logging import threading import json -import importlib.metadata from typing import Dict import fed.utils as fed_utils from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig -from fed._private.compatible_utils import _compare_version_strings +import fed._private.compatible_utils as compatible_utils from fed._private.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE from fed.proxy.barriers import ( add_two_dim_dict, @@ -22,7 +21,8 @@ SendProxy, RecvProxy ) -if _compare_version_strings(importlib.metadata.version('protobuf'), '4.0.0'): +if compatible_utils._compare_version_strings( + fed_utils.get_package_version('protobuf'), '4.0.0'): from fed.grpc import fed_pb2_in_protobuf4 as fed_pb2 from fed.grpc import fed_pb2_grpc_in_protobuf4 as fed_pb2_grpc else: @@ -35,14 +35,8 @@ def parse_grpc_options(proxy_config: CrossSiloCommConfig): grpc_channel_options = {} - if proxy_config is not None: - if proxy_config.messages_max_size_in_bytes is not None: - grpc_channel_options.update({ - 'grpc.max_send_message_length': - proxy_config.messages_max_size_in_bytes, - 'grpc.max_receive_message_length': - proxy_config.messages_max_size_in_bytes - }) + if proxy_config is not None and isinstance( + proxy_config, CrossSiloGrpcCommConfig): if isinstance(proxy_config, CrossSiloGrpcCommConfig): if proxy_config.grpc_channel_options is not None: grpc_channel_options.update(proxy_config.grpc_channel_options) diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index b70bf0e6..555b1f23 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -18,7 +18,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloCommConfig +from fed.config import CrossSiloGrpcCommConfig @fed.remote @@ -35,8 +35,11 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloCommConfig( - messages_max_size_in_bytes=100) + global_cross_silo_comm_config=CrossSiloGrpcCommConfig( + grpc_channel_options=[( + 'grpc.max_send_message_length', 100 + )] + ) ) def _assert_on_proxy(proxy_actor): diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py index 53b70cbd..fe2ff786 100644 --- a/tests/test_grpc_options_per_party.py +++ b/tests/test_grpc_options_per_party.py @@ -42,8 +42,11 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloCommConfig( - messages_max_size_in_bytes=100) + global_cross_silo_comm_config=CrossSiloGrpcCommConfig( + grpc_channel_options=[( + 'grpc.max_send_message_length', 100 + )] + ) ) def _assert_on_send_proxy(proxy_actor): @@ -104,8 +107,11 @@ def party_grpc_options(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloCommConfig( - messages_max_size_in_bytes=100) + global_cross_silo_comm_config=CrossSiloGrpcCommConfig( + grpc_channel_options=[( + 'grpc.max_send_message_length', 100 + )] + ) ) def _assert_on_send_proxy(proxy_actor): From aebf1071993a3bd35d6661aab7bc4d9f613eefbd Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 18 Jul 2023 11:41:40 +0800 Subject: [PATCH 19/22] rename to CrossSiloMsgConfig --- fed/_private/constants.py | 2 +- fed/api.py | 24 +++--- fed/config.py | 12 +-- fed/proxy/barriers.py | 81 ++----------------- fed/proxy/base_proxy.py | 80 ++++++++++++++++++ fed/{_private => proxy/grpc}/grpc_options.py | 0 fed/proxy/{ => grpc}/grpc_proxy.py | 44 +++++++--- .../test_unpickle_with_whitelist.py | 4 +- tests/test_exit_on_failure_sending.py | 4 +- tests/test_grpc_options_on_proxies.py | 4 +- tests/test_grpc_options_per_party.py | 12 +-- tests/test_retry_policy.py | 4 +- tests/test_setup_proxy_actor.py | 4 +- tests/test_transport_proxy.py | 18 ++--- tests/test_transport_proxy_tls.py | 6 +- 15 files changed, 168 insertions(+), 131 deletions(-) create mode 100644 fed/proxy/base_proxy.py rename fed/{_private => proxy/grpc}/grpc_options.py (100%) rename fed/proxy/{ => grpc}/grpc_proxy.py (87%) diff --git a/fed/_private/constants.py b/fed/_private/constants.py index 96681e37..09af728f 100644 --- a/fed/_private/constants.py +++ b/fed/_private/constants.py @@ -25,7 +25,7 @@ KEY_OF_TLS_CONFIG = "TLS_CONFIG" -KEY_OF_CROSS_SILO_COMM_CONFIG = "CROSS_SILO_COMM_CONFIG" +KEY_OF_CROSS_SILO_MSG_CONFIG = "CROSS_SILO_MSG_CONFIG" RAYFED_LOG_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- %(message)s" # noqa diff --git a/fed/api.py b/fed/api.py index 2ff4eddd..fcf6f974 100644 --- a/fed/api.py +++ b/fed/api.py @@ -33,11 +33,9 @@ send, start_recv_proxy, start_send_proxy, - SendProxy, - RecvProxy ) -from fed.config import CrossSiloCommConfig - +from fed.proxy.grpc.grpc_proxy import SendProxy, RecvProxy +from fed.config import CrossSiloMsgConfig from fed.fed_object import FedObject from fed.utils import is_ray_object_refs, setup_logger @@ -52,7 +50,7 @@ def init( enable_waiting_for_other_parties_ready: bool = False, send_proxy_cls: SendProxy = None, recv_proxy_cls: RecvProxy = None, - global_cross_silo_comm_config: Optional[CrossSiloCommConfig] = None, + global_cross_silo_comm_config: Optional[CrossSiloMsgConfig] = None, **kwargs, ): """ @@ -69,7 +67,7 @@ def init( # (Optional) the listen address, the `address` will be # used if not provided. 'listen_addr': '0.0.0.0:10001', - 'cross_silo_comm_config': CrossSiloCommConfig + 'cross_silo_comm_config': CrossSiloMsgConfig }, 'bob': { # The address for other parties. @@ -115,7 +113,7 @@ def init( are all ready if True. global_cross_silo_comm_config: Global cross-silo communication related config that are applied to all connections. Supported configs - can refer to CrossSiloCommConfig in config.py. + can refer to CrossSiloMsgConfig in config.py. Examples: >>> import fed @@ -142,7 +140,7 @@ def init( ), 'Cert or key are not in tls_config.' global_cross_silo_comm_config = \ - global_cross_silo_comm_config or CrossSiloCommConfig() + global_cross_silo_comm_config or CrossSiloMsgConfig() # A Ray private accessing, should be replaced in public API. compatible_utils._init_internal_kv() @@ -153,7 +151,7 @@ def init( } job_config = { - constants.KEY_OF_CROSS_SILO_COMM_CONFIG: + constants.KEY_OF_CROSS_SILO_MSG_CONFIG: global_cross_silo_comm_config, } compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, @@ -175,7 +173,9 @@ def init( exit_when_failure_sending=global_cross_silo_comm_config.exit_on_sending_failure) if recv_proxy_cls is None: - from fed.proxy.grpc_proxy import GrpcRecvProxy + logger.debug( + "Not declaring recver proxy class, using `GrpcRecvProxy` as default.") + from fed.proxy.grpc.grpc_proxy import GrpcRecvProxy recv_proxy_cls = GrpcRecvProxy # Start recv proxy start_recv_proxy( @@ -188,7 +188,9 @@ def init( ) if send_proxy_cls is None: - from fed.proxy.grpc_proxy import GrpcSendProxy + logger.debug( + "Not declaring send proxy class, using `GrpcSendProxy` as default.") + from fed.proxy.grpc.grpc_proxy import GrpcSendProxy send_proxy_cls = GrpcSendProxy start_send_proxy( cluster=cluster, diff --git a/fed/config.py b/fed/config.py index 21c495da..8c179adc 100644 --- a/fed/config.py +++ b/fed/config.py @@ -41,8 +41,8 @@ def __init__(self, raw_bytes: bytes) -> None: @property def cross_silo_comm_config(self): return self._data.get( - fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG, - CrossSiloCommConfig()) + fed_constants.KEY_OF_CROSS_SILO_MSG_CONFIG, + CrossSiloMsgConfig()) # A module level cache for the cluster configurations. @@ -74,7 +74,7 @@ def get_job_config(): @dataclass -class CrossSiloCommConfig: +class CrossSiloMsgConfig: """A class to store parameters used for Proxy Actor Attributes: @@ -121,13 +121,13 @@ def from_json(cls, json_str): @classmethod def from_dict(cls, data: Dict): - """Initialize CrossSiloCommConfig from a dictionary. + """Initialize CrossSiloMsgConfig from a dictionary. Args: data (Dict): Dictionary with keys as member variable names. Returns: - CrossSiloCommConfig: An instance of CrossSiloCommConfig. + CrossSiloMsgConfig: An instance of CrossSiloMsgConfig. """ # Get the attributes of the class attrs = {attr for attr, _ in cls.__annotations__.items()} @@ -137,7 +137,7 @@ def from_dict(cls, data: Dict): @dataclass -class CrossSiloGrpcCommConfig(CrossSiloCommConfig): +class GrpcCrossSiloMsgConfig(CrossSiloMsgConfig): """A class to store parameters used for GRPC communication Attributes: diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 981fbafa..47118d6f 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging import time import copy @@ -21,18 +20,10 @@ import ray import fed.config as fed_config -import fed.utils as fed_utils -from fed._private import constants -import fed._private.compatible_utils as compatible_utils -from fed.config import get_job_config, CrossSiloCommConfig -if compatible_utils._compare_version_strings( - fed_utils.get_package_version('protobuf'), '4.0.0'): - from fed.grpc import fed_pb2_in_protobuf4 as fed_pb2 - from fed.grpc import fed_pb2_grpc_in_protobuf4 as fed_pb2_grpc -else: - from fed.grpc import fed_pb2_in_protobuf3 as fed_pb2 - from fed.grpc import fed_pb2_grpc_in_protobuf3 as fed_pb2_grpc +from fed.config import get_job_config +from fed.proxy.base_proxy import SendProxy, RecvProxy from fed.utils import setup_logger +from fed._private import constants from fed._private.global_context import get_global_context logger = logging.getLogger(__name__) @@ -63,68 +54,6 @@ def pop_from_two_dim_dict(the_dict, key_a, key_b): return the_dict[key_a].pop(key_b) -class SendProxy(abc.ABC): - def __init__( - self, - cluster: Dict, - party: str, - tls_config: Dict, - proxy_config: CrossSiloCommConfig = None - ) -> None: - self._cluster = cluster - self._party = party - self._tls_config = tls_config - self._proxy_config = proxy_config - - @abc.abstractmethod - async def send( - self, - dest_party, - data, - upstream_seq_id, - downstream_seq_id - ): - pass - - async def is_ready(self): - return True - - async def get_proxy_config(self, dest_party=None): - return self._proxy_config - - -class RecvProxy(abc.ABC): - def __init__( - self, - listen_addr: str, - party: str, - tls_config: Dict, - proxy_config: CrossSiloCommConfig = None - ) -> None: - self._listen_addr = listen_addr - self._party = party - self._tls_config = tls_config - self._proxy_config = proxy_config - - @abc.abstractmethod - def start(self): - pass - - @abc.abstractmethod - async def get_data( - self, - src_party, - upstream_seq_id, - curr_seq_id): - pass - - async def is_ready(self): - return True - - async def get_proxy_config(self): - return self._proxy_config - - @ray.remote class SendProxyActor: def __init__( @@ -247,7 +176,7 @@ def start_recv_proxy( logging_level: str, tls_config=None, proxy_cls=None, - proxy_config: Optional[fed_config.CrossSiloCommConfig] = None + proxy_config: Optional[fed_config.CrossSiloMsgConfig] = None ): # Create RecevrProxyActor @@ -292,7 +221,7 @@ def start_send_proxy( logging_level: str, tls_config: Dict = None, proxy_cls=None, - proxy_config: Optional[fed_config.CrossSiloCommConfig] = None + proxy_config: Optional[fed_config.CrossSiloMsgConfig] = None ): # Create SendProxyActor global _SEND_PROXY_ACTOR diff --git a/fed/proxy/base_proxy.py b/fed/proxy/base_proxy.py new file mode 100644 index 00000000..51c0a2fa --- /dev/null +++ b/fed/proxy/base_proxy.py @@ -0,0 +1,80 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Dict + +from fed.config import CrossSiloMsgConfig + + +class SendProxy(abc.ABC): + def __init__( + self, + cluster: Dict, + party: str, + tls_config: Dict, + proxy_config: CrossSiloMsgConfig = None + ) -> None: + self._cluster = cluster + self._party = party + self._tls_config = tls_config + self._proxy_config = proxy_config + + @abc.abstractmethod + async def send( + self, + dest_party, + data, + upstream_seq_id, + downstream_seq_id + ): + pass + + async def is_ready(self): + return True + + async def get_proxy_config(self, dest_party=None): + return self._proxy_config + + +class RecvProxy(abc.ABC): + def __init__( + self, + listen_addr: str, + party: str, + tls_config: Dict, + proxy_config: CrossSiloMsgConfig = None + ) -> None: + self._listen_addr = listen_addr + self._party = party + self._tls_config = tls_config + self._proxy_config = proxy_config + + @abc.abstractmethod + def start(self): + pass + + @abc.abstractmethod + async def get_data( + self, + src_party, + upstream_seq_id, + curr_seq_id): + pass + + async def is_ready(self): + return True + + async def get_proxy_config(self): + return self._proxy_config diff --git a/fed/_private/grpc_options.py b/fed/proxy/grpc/grpc_options.py similarity index 100% rename from fed/_private/grpc_options.py rename to fed/proxy/grpc/grpc_options.py diff --git a/fed/proxy/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py similarity index 87% rename from fed/proxy/grpc_proxy.py rename to fed/proxy/grpc/grpc_proxy.py index a02bcc51..f9d0e968 100644 --- a/fed/proxy/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -1,3 +1,17 @@ +# Copyright 2023 The RayFed Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import asyncio import copy import cloudpickle @@ -10,17 +24,16 @@ import fed.utils as fed_utils -from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig +from fed.config import CrossSiloMsgConfig, GrpcCrossSiloMsgConfig import fed._private.compatible_utils as compatible_utils -from fed._private.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE +from fed.proxy.grpc.grpc_options import _DEFAULT_GRPC_CHANNEL_OPTIONS, _GRPC_SERVICE from fed.proxy.barriers import ( add_two_dim_dict, get_from_two_dim_dict, pop_from_two_dim_dict, key_exists_in_two_dim_dict, - SendProxy, - RecvProxy ) +from fed.proxy.base_proxy import SendProxy, RecvProxy if compatible_utils._compare_version_strings( fed_utils.get_package_version('protobuf'), '4.0.0'): from fed.grpc import fed_pb2_in_protobuf4 as fed_pb2 @@ -33,11 +46,24 @@ logger = logging.getLogger(__name__) -def parse_grpc_options(proxy_config: CrossSiloCommConfig): +def parse_grpc_options(proxy_config: CrossSiloMsgConfig): + """ + Extract certain fields in `CrossSiloGrpcCommConfig` into the + "grpc_channel_options". Note that the resulting dict's key + may not be identical to the config name, but a grpc-supported + option name. + + Args: + proxy_config (CrossSiloMsgConfig): The proxy configuration + from which to extract the gRPC options. + + Returns: + dict: A dictionary containing the gRPC channel options. + """ grpc_channel_options = {} if proxy_config is not None and isinstance( - proxy_config, CrossSiloGrpcCommConfig): - if isinstance(proxy_config, CrossSiloGrpcCommConfig): + proxy_config, GrpcCrossSiloMsgConfig): + if isinstance(proxy_config, GrpcCrossSiloMsgConfig): if proxy_config.grpc_channel_options is not None: grpc_channel_options.update(proxy_config.grpc_channel_options) if proxy_config.grpc_retry_policy is not None: @@ -64,7 +90,7 @@ def __init__( cluster: Dict, party: str, tls_config: Dict, - proxy_config: CrossSiloCommConfig = None + proxy_config: CrossSiloMsgConfig = None ) -> None: super().__init__(cluster, party, tls_config, proxy_config) self._grpc_metadata = proxy_config.http_header or {} @@ -174,7 +200,7 @@ def __init__( listen_addr: str, party: str, tls_config: Dict, - proxy_config: CrossSiloCommConfig + proxy_config: CrossSiloMsgConfig ) -> None: super().__init__(listen_addr, party, tls_config, proxy_config) self._grpc_options = copy.deepcopy(_DEFAULT_GRPC_CHANNEL_OPTIONS) diff --git a/tests/serializations_tests/test_unpickle_with_whitelist.py b/tests/serializations_tests/test_unpickle_with_whitelist.py index 0cf75e66..1b9eb184 100644 --- a/tests/serializations_tests/test_unpickle_with_whitelist.py +++ b/tests/serializations_tests/test_unpickle_with_whitelist.py @@ -19,7 +19,7 @@ import multiprocessing import numpy -from fed.config import CrossSiloCommConfig +from fed.config import CrossSiloMsgConfig @fed.remote @@ -53,7 +53,7 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloMsgConfig( serializing_allowed_list=allowed_list )) diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 555e7557..c7904fe8 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -19,7 +19,7 @@ import fed import fed._private.compatible_utils as compatible_utils -from fed.config import CrossSiloGrpcCommConfig +from fed.config import GrpcCrossSiloMsgConfig import signal @@ -63,7 +63,7 @@ def run(party, is_inner_party): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } - cross_silo_comm_config = CrossSiloGrpcCommConfig( + cross_silo_comm_config = GrpcCrossSiloMsgConfig( grpc_retry_policy=retry_policy, exit_on_sending_failure=True ) diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index 555b1f23..5be3ae38 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -18,7 +18,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloGrpcCommConfig +from fed.config import GrpcCrossSiloMsgConfig @fed.remote @@ -35,7 +35,7 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloGrpcCommConfig( + global_cross_silo_comm_config=GrpcCrossSiloMsgConfig( grpc_channel_options=[( 'grpc.max_send_message_length', 100 )] diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py index fe2ff786..ef485534 100644 --- a/tests/test_grpc_options_per_party.py +++ b/tests/test_grpc_options_per_party.py @@ -18,7 +18,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig +from fed.config import GrpcCrossSiloMsgConfig @fed.remote @@ -31,7 +31,7 @@ def run(party): cluster = { 'alice': { 'address': '127.0.0.1:11010', - 'cross_silo_comm_config': CrossSiloGrpcCommConfig( + 'cross_silo_comm_config': GrpcCrossSiloMsgConfig( grpc_channel_options=[ ('grpc.default_authority', 'alice'), ('grpc.max_send_message_length', 200) @@ -42,7 +42,7 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloGrpcCommConfig( + global_cross_silo_comm_config=GrpcCrossSiloMsgConfig( grpc_channel_options=[( 'grpc.max_send_message_length', 100 )] @@ -89,7 +89,7 @@ def party_grpc_options(party): cluster = { 'alice': { 'address': '127.0.0.1:11010', - 'cross_silo_comm_config': CrossSiloGrpcCommConfig( + 'cross_silo_comm_config': GrpcCrossSiloMsgConfig( grpc_channel_options=[ ('grpc.default_authority', 'alice'), ('grpc.max_send_message_length', 51 * 1024 * 1024) @@ -97,7 +97,7 @@ def party_grpc_options(party): }, 'bob': { 'address': '127.0.0.1:11011', - 'cross_silo_comm_config': CrossSiloGrpcCommConfig( + 'cross_silo_comm_config': GrpcCrossSiloMsgConfig( grpc_channel_options=[ ('grpc.default_authority', 'bob'), ('grpc.max_send_message_length', 50 * 1024 * 1024) @@ -107,7 +107,7 @@ def party_grpc_options(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloGrpcCommConfig( + global_cross_silo_comm_config=GrpcCrossSiloMsgConfig( grpc_channel_options=[( 'grpc.max_send_message_length', 100 )] diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 15f37fc0..4dcb1e3a 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -20,7 +20,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloGrpcCommConfig +from fed.config import GrpcCrossSiloMsgConfig @fed.remote @@ -53,7 +53,7 @@ def run(party, is_inner_party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloGrpcCommConfig( + global_cross_silo_comm_config=GrpcCrossSiloMsgConfig( grpc_retry_policy=retry_policy ) ) diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index 9d1ee40b..22045d14 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -20,7 +20,7 @@ import fed._private.compatible_utils as compatible_utils import ray -from fed.config import CrossSiloCommConfig +from fed.config import CrossSiloMsgConfig def run(party): @@ -65,7 +65,7 @@ def run_failure(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloCommConfig( + global_cross_silo_comm_config=CrossSiloMsgConfig( send_resource_label=send_proxy_resources, recv_resource_label=recv_proxy_resources, timeout_in_ms=10*1000, diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 27d3b27a..765529e0 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -20,7 +20,7 @@ import fed.utils as fed_utils import fed._private.compatible_utils as compatible_utils -from fed.config import CrossSiloCommConfig, CrossSiloGrpcCommConfig +from fed.config import CrossSiloMsgConfig, GrpcCrossSiloMsgConfig from fed._private import constants from fed._private import global_context from fed.proxy.barriers import ( @@ -28,7 +28,7 @@ start_recv_proxy, start_send_proxy ) -from fed.proxy.grpc_proxy import GrpcSendProxy, GrpcRecvProxy +from fed.proxy.grpc.grpc_proxy import GrpcSendProxy, GrpcRecvProxy if compatible_utils._compare_version_strings( fed_utils.get_package_version('protobuf'), '4.0.0'): from fed.grpc import fed_pb2_in_protobuf4 as fed_pb2 @@ -59,7 +59,7 @@ def test_n_to_1_transport(): SERVER_ADDRESS = "127.0.0.1:12344" party = 'test_party' cluster_config = {'test_party': {'address': SERVER_ADDRESS}} - config = CrossSiloGrpcCommConfig() + config = GrpcCrossSiloMsgConfig() start_recv_proxy( cluster_config, party, @@ -179,11 +179,11 @@ def test_send_grpc_with_meta(): constants.KEY_OF_TLS_CONFIG: "", } metadata = {"key": "value"} - send_proxy_config = CrossSiloCommConfig( + send_proxy_config = CrossSiloMsgConfig( http_header=metadata ) job_config = { - constants.KEY_OF_CROSS_SILO_COMM_CONFIG: + constants.KEY_OF_CROSS_SILO_MSG_CONFIG: send_proxy_config, } compatible_utils._init_internal_kv() @@ -205,7 +205,7 @@ def test_send_grpc_with_meta(): party, logging_level='info', proxy_cls=GrpcSendProxy, - proxy_config=CrossSiloGrpcCommConfig()) + proxy_config=GrpcCrossSiloMsgConfig()) sent_objs = [] sent_obj = send(party, "data", 0, 1) sent_objs.append(sent_obj) @@ -224,10 +224,10 @@ def test_send_grpc_with_party_specific_meta(): constants.KEY_OF_CURRENT_PARTY_NAME: "", constants.KEY_OF_TLS_CONFIG: "", } - send_proxy_config = CrossSiloCommConfig( + send_proxy_config = CrossSiloMsgConfig( http_header={"key": "value"}) job_config = { - constants.KEY_OF_CROSS_SILO_COMM_CONFIG: + constants.KEY_OF_CROSS_SILO_MSG_CONFIG: send_proxy_config, } compatible_utils._init_internal_kv() @@ -242,7 +242,7 @@ def test_send_grpc_with_party_specific_meta(): cluster_parties_config = { 'test_party': { 'address': SERVER_ADDRESS, - 'cross_silo_comm_config': CrossSiloCommConfig( + 'cross_silo_comm_config': CrossSiloMsgConfig( http_header={"token": "test-party-token"}) } } diff --git a/tests/test_transport_proxy_tls.py b/tests/test_transport_proxy_tls.py index 6b7a7d3f..71f10ef7 100644 --- a/tests/test_transport_proxy_tls.py +++ b/tests/test_transport_proxy_tls.py @@ -22,8 +22,8 @@ from fed._private import constants from fed._private import global_context from fed.proxy.barriers import send, start_recv_proxy, start_send_proxy -from fed.proxy.grpc_proxy import GrpcSendProxy, GrpcRecvProxy -from fed.config import CrossSiloGrpcCommConfig +from fed.proxy.grpc.grpc_proxy import GrpcSendProxy, GrpcRecvProxy +from fed.config import GrpcCrossSiloMsgConfig def test_n_to_1_transport(): @@ -57,7 +57,7 @@ def test_n_to_1_transport(): SERVER_ADDRESS = "127.0.0.1:65422" party = 'test_party' cluster_config = {'test_party': {'address': SERVER_ADDRESS}} - config = CrossSiloGrpcCommConfig() + config = GrpcCrossSiloMsgConfig() start_recv_proxy( cluster_config, party, From 0621cf9144646d71ffb8fcb4568c29f4fd55127a Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 18 Jul 2023 11:48:00 +0800 Subject: [PATCH 20/22] rename api parameter --- fed/api.py | 18 +++++++++--------- .../test_unpickle_with_whitelist.py | 2 +- tests/test_exit_on_failure_sending.py | 2 +- tests/test_grpc_options_on_proxies.py | 2 +- tests/test_grpc_options_per_party.py | 4 ++-- tests/test_retry_policy.py | 2 +- tests/test_setup_proxy_actor.py | 2 +- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/fed/api.py b/fed/api.py index fcf6f974..4ffaf3ba 100644 --- a/fed/api.py +++ b/fed/api.py @@ -50,7 +50,7 @@ def init( enable_waiting_for_other_parties_ready: bool = False, send_proxy_cls: SendProxy = None, recv_proxy_cls: RecvProxy = None, - global_cross_silo_comm_config: Optional[CrossSiloMsgConfig] = None, + global_cross_silo_msg_config: Optional[CrossSiloMsgConfig] = None, **kwargs, ): """ @@ -111,8 +111,8 @@ def init( `warning`, `error`, `critical`, not case sensititive. enable_waiting_for_other_parties_ready: ping other parties until they are all ready if True. - global_cross_silo_comm_config: Global cross-silo communication related - config that are applied to all connections. Supported configs + global_cross_silo_msg_config: Global cross-silo message related + configs that are applied to all connections. Supported configs can refer to CrossSiloMsgConfig in config.py. Examples: @@ -139,8 +139,8 @@ def init( 'cert' in tls_config and 'key' in tls_config ), 'Cert or key are not in tls_config.' - global_cross_silo_comm_config = \ - global_cross_silo_comm_config or CrossSiloMsgConfig() + global_cross_silo_msg_config = \ + global_cross_silo_msg_config or CrossSiloMsgConfig() # A Ray private accessing, should be replaced in public API. compatible_utils._init_internal_kv() @@ -152,7 +152,7 @@ def init( job_config = { constants.KEY_OF_CROSS_SILO_MSG_CONFIG: - global_cross_silo_comm_config, + global_cross_silo_msg_config, } compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)) @@ -170,7 +170,7 @@ def init( logger.info(f'Started rayfed with {cluster_config}') get_global_context().get_cleanup_manager().start( - exit_when_failure_sending=global_cross_silo_comm_config.exit_on_sending_failure) + exit_when_failure_sending=global_cross_silo_msg_config.exit_on_sending_failure) if recv_proxy_cls is None: logger.debug( @@ -184,7 +184,7 @@ def init( logging_level=logging_level, tls_config=tls_config, proxy_cls=recv_proxy_cls, - proxy_config=global_cross_silo_comm_config + proxy_config=global_cross_silo_msg_config ) if send_proxy_cls is None: @@ -198,7 +198,7 @@ def init( logging_level=logging_level, tls_config=tls_config, proxy_cls=send_proxy_cls, - proxy_config=global_cross_silo_comm_config + proxy_config=global_cross_silo_msg_config ) if enable_waiting_for_other_parties_ready: diff --git a/tests/serializations_tests/test_unpickle_with_whitelist.py b/tests/serializations_tests/test_unpickle_with_whitelist.py index 1b9eb184..ec4f8463 100644 --- a/tests/serializations_tests/test_unpickle_with_whitelist.py +++ b/tests/serializations_tests/test_unpickle_with_whitelist.py @@ -53,7 +53,7 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloMsgConfig( + global_cross_silo_msg_config=CrossSiloMsgConfig( serializing_allowed_list=allowed_list )) diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index c7904fe8..08979858 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -71,7 +71,7 @@ def run(party, is_inner_party): cluster=cluster, party=party, logging_level='debug', - global_cross_silo_comm_config=cross_silo_comm_config + global_cross_silo_msg_config=cross_silo_comm_config ) o = f.party("alice").remote() diff --git a/tests/test_grpc_options_on_proxies.py b/tests/test_grpc_options_on_proxies.py index 5be3ae38..34735b8d 100644 --- a/tests/test_grpc_options_on_proxies.py +++ b/tests/test_grpc_options_on_proxies.py @@ -35,7 +35,7 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=GrpcCrossSiloMsgConfig( + global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( grpc_channel_options=[( 'grpc.max_send_message_length', 100 )] diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py index ef485534..810c6665 100644 --- a/tests/test_grpc_options_per_party.py +++ b/tests/test_grpc_options_per_party.py @@ -42,7 +42,7 @@ def run(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=GrpcCrossSiloMsgConfig( + global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( grpc_channel_options=[( 'grpc.max_send_message_length', 100 )] @@ -107,7 +107,7 @@ def party_grpc_options(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=GrpcCrossSiloMsgConfig( + global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( grpc_channel_options=[( 'grpc.max_send_message_length', 100 )] diff --git a/tests/test_retry_policy.py b/tests/test_retry_policy.py index 4dcb1e3a..af6d9e83 100644 --- a/tests/test_retry_policy.py +++ b/tests/test_retry_policy.py @@ -53,7 +53,7 @@ def run(party, is_inner_party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=GrpcCrossSiloMsgConfig( + global_cross_silo_msg_config=GrpcCrossSiloMsgConfig( grpc_retry_policy=retry_policy ) ) diff --git a/tests/test_setup_proxy_actor.py b/tests/test_setup_proxy_actor.py index 22045d14..00215e38 100644 --- a/tests/test_setup_proxy_actor.py +++ b/tests/test_setup_proxy_actor.py @@ -65,7 +65,7 @@ def run_failure(party): fed.init( cluster=cluster, party=party, - global_cross_silo_comm_config=CrossSiloMsgConfig( + global_cross_silo_msg_config=CrossSiloMsgConfig( send_resource_label=send_proxy_resources, recv_resource_label=recv_proxy_resources, timeout_in_ms=10*1000, From 8818c8f045d9221392fef56b3dfad31e1ac153d2 Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 18 Jul 2023 14:17:04 +0800 Subject: [PATCH 21/22] rename parameter --- fed/_private/serialization_utils.py | 2 +- fed/api.py | 2 +- fed/config.py | 2 +- fed/proxy/barriers.py | 10 +++++----- fed/proxy/grpc/grpc_proxy.py | 2 +- tests/test_exit_on_failure_sending.py | 4 ++-- tests/test_grpc_options_per_party.py | 6 +++--- tests/test_transport_proxy.py | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/fed/_private/serialization_utils.py b/fed/_private/serialization_utils.py index ea7abe18..07182cf6 100644 --- a/fed/_private/serialization_utils.py +++ b/fed/_private/serialization_utils.py @@ -64,7 +64,7 @@ def _apply_loads_function_with_whitelist(): global _pickle_whitelist _pickle_whitelist = fed_config.get_job_config() \ - .cross_silo_comm_config.serializing_allowed_list + .cross_silo_msg_config.serializing_allowed_list if _pickle_whitelist is None: return diff --git a/fed/api.py b/fed/api.py index 4ffaf3ba..22dff338 100644 --- a/fed/api.py +++ b/fed/api.py @@ -67,7 +67,7 @@ def init( # (Optional) the listen address, the `address` will be # used if not provided. 'listen_addr': '0.0.0.0:10001', - 'cross_silo_comm_config': CrossSiloMsgConfig + 'cross_silo_msg_config': CrossSiloMsgConfig }, 'bob': { # The address for other parties. diff --git a/fed/config.py b/fed/config.py index 8c179adc..84af784a 100644 --- a/fed/config.py +++ b/fed/config.py @@ -39,7 +39,7 @@ def __init__(self, raw_bytes: bytes) -> None: self._data = cloudpickle.loads(raw_bytes) @property - def cross_silo_comm_config(self): + def cross_silo_msg_config(self): return self._data.get( fed_constants.KEY_OF_CROSS_SILO_MSG_CONFIG, CrossSiloMsgConfig()) diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index 47118d6f..a1a65736 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -75,9 +75,9 @@ def __init__( self._cluster = cluster self._party = party self._tls_config = tls_config - cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config + cross_silo_msg_config = fed_config.get_job_config().cross_silo_msg_config self._proxy_instance: SendProxy = proxy_cls( - cluster, party, tls_config, cross_silo_comm_config) + cluster, party, tls_config, cross_silo_msg_config) async def is_ready(self): res = await self._proxy_instance.is_ready() @@ -141,9 +141,9 @@ def __init__( self._listen_addr = listen_addr self._party = party self._tls_config = tls_config - cross_silo_comm_config = fed_config.get_job_config().cross_silo_comm_config + cross_silo_msg_config = fed_config.get_job_config().cross_silo_msg_config self._proxy_instance: RecvProxy = proxy_cls( - listen_addr, party, tls_config, cross_silo_comm_config) + listen_addr, party, tls_config, cross_silo_msg_config) async def start(self): await self._proxy_instance.start() @@ -246,7 +246,7 @@ def start_send_proxy( logging_level=logging_level, proxy_cls=proxy_cls ) - timeout = get_job_config().cross_silo_comm_config.timeout_in_ms / 1000 + timeout = get_job_config().cross_silo_msg_config.timeout_in_ms / 1000 assert ray.get(_SEND_PROXY_ACTOR.is_ready.remote(), timeout=timeout) logger.info("SendProxyActor has successfully created.") diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index f9d0e968..0735ac6e 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -143,7 +143,7 @@ def get_grpc_config_by_party(self, dest_party): grpc_options = self._grpc_options dest_party_comm_config = self._cluster[dest_party].get( - 'cross_silo_comm_config', None) + 'cross_silo_msg_config', None) if dest_party_comm_config is not None: if dest_party_comm_config.http_header is not None: dest_party_grpc_metadata = dict(dest_party_comm_config.http_header) diff --git a/tests/test_exit_on_failure_sending.py b/tests/test_exit_on_failure_sending.py index 08979858..33b3f98c 100644 --- a/tests/test_exit_on_failure_sending.py +++ b/tests/test_exit_on_failure_sending.py @@ -63,7 +63,7 @@ def run(party, is_inner_party): "backoffMultiplier": 1, "retryableStatusCodes": ["UNAVAILABLE"], } - cross_silo_comm_config = GrpcCrossSiloMsgConfig( + cross_silo_msg_config = GrpcCrossSiloMsgConfig( grpc_retry_policy=retry_policy, exit_on_sending_failure=True ) @@ -71,7 +71,7 @@ def run(party, is_inner_party): cluster=cluster, party=party, logging_level='debug', - global_cross_silo_msg_config=cross_silo_comm_config + global_cross_silo_msg_config=cross_silo_msg_config ) o = f.party("alice").remote() diff --git a/tests/test_grpc_options_per_party.py b/tests/test_grpc_options_per_party.py index 810c6665..95b1465c 100644 --- a/tests/test_grpc_options_per_party.py +++ b/tests/test_grpc_options_per_party.py @@ -31,7 +31,7 @@ def run(party): cluster = { 'alice': { 'address': '127.0.0.1:11010', - 'cross_silo_comm_config': GrpcCrossSiloMsgConfig( + 'cross_silo_msg_config': GrpcCrossSiloMsgConfig( grpc_channel_options=[ ('grpc.default_authority', 'alice'), ('grpc.max_send_message_length', 200) @@ -89,7 +89,7 @@ def party_grpc_options(party): cluster = { 'alice': { 'address': '127.0.0.1:11010', - 'cross_silo_comm_config': GrpcCrossSiloMsgConfig( + 'cross_silo_msg_config': GrpcCrossSiloMsgConfig( grpc_channel_options=[ ('grpc.default_authority', 'alice'), ('grpc.max_send_message_length', 51 * 1024 * 1024) @@ -97,7 +97,7 @@ def party_grpc_options(party): }, 'bob': { 'address': '127.0.0.1:11011', - 'cross_silo_comm_config': GrpcCrossSiloMsgConfig( + 'cross_silo_msg_config': GrpcCrossSiloMsgConfig( grpc_channel_options=[ ('grpc.default_authority', 'bob'), ('grpc.max_send_message_length', 50 * 1024 * 1024) diff --git a/tests/test_transport_proxy.py b/tests/test_transport_proxy.py index 765529e0..54aa3c24 100644 --- a/tests/test_transport_proxy.py +++ b/tests/test_transport_proxy.py @@ -242,7 +242,7 @@ def test_send_grpc_with_party_specific_meta(): cluster_parties_config = { 'test_party': { 'address': SERVER_ADDRESS, - 'cross_silo_comm_config': CrossSiloMsgConfig( + 'cross_silo_msg_config': CrossSiloMsgConfig( http_header={"token": "test-party-token"}) } } From 922cd5bd34227879813ba816d26ec91024c2944f Mon Sep 17 00:00:00 2001 From: paer Date: Tue, 18 Jul 2023 14:30:33 +0800 Subject: [PATCH 22/22] rename parameter --- fed/proxy/grpc/grpc_proxy.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index 0735ac6e..2502c8f4 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -142,16 +142,16 @@ def get_grpc_config_by_party(self, dest_party): grpc_metadata = self._grpc_metadata grpc_options = self._grpc_options - dest_party_comm_config = self._cluster[dest_party].get( + dest_party_msg_config = self._cluster[dest_party].get( 'cross_silo_msg_config', None) - if dest_party_comm_config is not None: - if dest_party_comm_config.http_header is not None: - dest_party_grpc_metadata = dict(dest_party_comm_config.http_header) + if dest_party_msg_config is not None: + if dest_party_msg_config.http_header is not None: + dest_party_grpc_metadata = dict(dest_party_msg_config.http_header) grpc_metadata = { **grpc_metadata, **dest_party_grpc_metadata } - dest_party_grpc_options = parse_grpc_options(dest_party_comm_config) + dest_party_grpc_options = parse_grpc_options(dest_party_msg_config) grpc_options = { **grpc_options, **dest_party_grpc_options }