Skip to content

Commit

Permalink
Pass server_uris in when creating MPCInstance (facebookresearch#1913)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#1913

X-link: facebookresearch/fbpcp#459

As title. server_uris will be passed into MPCInstance upon creation on the publisher side. We need to store server_uris in the MPCInstance to be propagated back to the partner side.

Reviewed By: joe1234wu

Differential Revision: D40997128

fbshipit-source-id: f1b1927b777d5ded5c8b75d930fe24bbd2f5a43a
  • Loading branch information
YigeZhu authored and facebook-github-bot committed Nov 11, 2022
1 parent d7114ca commit c7121df
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
3 changes: 2 additions & 1 deletion fbpcs/private_computation/service/mpc/mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def create_instance(
num_workers: int,
server_ips: Optional[List[str]] = None,
game_args: Optional[List[Dict[str, Any]]] = None,
server_uris: Optional[List[str]] = None,
) -> MPCInstance:
self.logger.info(f"Creating MPC instance: {instance_id}")

Expand All @@ -123,7 +124,7 @@ def create_instance(
[],
MPCInstanceStatus.CREATED,
game_args,
[],
server_uris,
)

self.instance_repository.create(instance)
Expand Down
55 changes: 37 additions & 18 deletions fbpcs/private_computation/test/service/mpc/test_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
"concurrency": TEST_CONCURRENCY_ARGS,
}
]
TEST_SERVER_URIS = [
"node1.publisher.study1.pci.facebook.com",
"node2.publisher.study1.pci.facebook.com",
]


class TestMPCService(IsolatedAsyncioTestCase):
Expand All @@ -62,7 +66,7 @@ def setUp(self):
)

@staticmethod
def _get_sample_mpcinstance():
def _get_sample_mpcinstance() -> MPCInstance:
return MPCInstance(
TEST_INSTANCE_ID,
TEST_GAME_NAME,
Expand All @@ -72,11 +76,11 @@ def _get_sample_mpcinstance():
[],
MPCInstanceStatus.CREATED,
GAME_ARGS,
[],
TEST_SERVER_URIS,
)

@staticmethod
def _get_sample_mpcinstance_with_game_args():
def _get_sample_mpcinstance_with_game_args() -> MPCInstance:
return MPCInstance(
TEST_INSTANCE_ID,
TEST_GAME_NAME,
Expand All @@ -86,11 +90,11 @@ def _get_sample_mpcinstance_with_game_args():
[],
MPCInstanceStatus.CREATED,
GAME_ARGS,
[],
TEST_SERVER_URIS,
)

@staticmethod
def _get_sample_mpcinstance_client():
def _get_sample_mpcinstance_client() -> MPCInstance:
return MPCInstance(
TEST_INSTANCE_ID,
TEST_GAME_NAME,
Expand All @@ -100,10 +104,10 @@ def _get_sample_mpcinstance_client():
[],
MPCInstanceStatus.CREATED,
GAME_ARGS,
[],
TEST_SERVER_URIS,
)

async def test_spin_up_containers_onedocker_inconsistent_arguments(self):
async def test_spin_up_containers_onedocker_inconsistent_arguments(self) -> None:
with self.assertRaisesRegex(
ValueError,
"The number of containers is not consistent with the number of game argument dictionary.",
Expand All @@ -126,45 +130,51 @@ async def test_spin_up_containers_onedocker_inconsistent_arguments(self):
ip_addresses=TEST_SERVER_IPS,
)

def test_create_instance_with_game_args(self):
def test_create_instance_with_game_args(self) -> None:
self.mpc_service.create_instance(
instance_id=TEST_INSTANCE_ID,
game_name=TEST_GAME_NAME,
mpc_party=TEST_MPC_ROLE,
num_workers=TEST_NUM_WORKERS,
server_ips=TEST_SERVER_IPS,
game_args=GAME_ARGS,
server_uris=TEST_SERVER_URIS,
)
# pyre-ignore [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `assert_called`.
self.mpc_service.instance_repository.create.assert_called()
self.assertEqual(
self._get_sample_mpcinstance_with_game_args(),
# pyre-ignore [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `call_args`.
self.mpc_service.instance_repository.create.call_args[0][0],
)

def test_create_instance(self):
def test_create_instance(self) -> None:
self.mpc_service.create_instance(
instance_id=TEST_INSTANCE_ID,
game_name=TEST_GAME_NAME,
mpc_party=TEST_MPC_ROLE,
num_workers=TEST_NUM_WORKERS,
server_ips=TEST_SERVER_IPS,
game_args=GAME_ARGS,
server_uris=TEST_SERVER_URIS,
)
# check that instance with correct instance_id was created
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `assert_called`.
self.mpc_service.instance_repository.create.assert_called()
self.assertEqual(
self._get_sample_mpcinstance(),
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `call_args`.
self.mpc_service.instance_repository.create.call_args[0][0],
)

def _read_side_effect_start(self, instance_id: str):
def _read_side_effect_start(self, instance_id: str) -> MPCInstance:
"""mock MPCInstanceRepository.read for test_start"""
if instance_id == TEST_INSTANCE_ID:
return self._get_sample_mpcinstance()
else:
raise RuntimeError(f"{instance_id} does not exist")

def test_start_instance(self):
def test_start_instance(self) -> None:
self.mpc_service.instance_repository.read = MagicMock(
side_effect=self._read_side_effect_start
)
Expand All @@ -187,20 +197,22 @@ def test_start_instance(self):
)
# check that update is called with correct status
self.mpc_service.start_instance(TEST_INSTANCE_ID)
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`.
self.mpc_service.instance_repository.update.assert_called()
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `call_args_list`.
latest_update = self.mpc_service.instance_repository.update.call_args_list[-1]
updated_status = latest_update[0][0].status
self.assertEqual(updated_status, MPCInstanceStatus.STARTED)

def test_start_instance_missing_ips(self):
def test_start_instance_missing_ips(self) -> None:
self.mpc_service.instance_repository.read = MagicMock(
return_value=self._get_sample_mpcinstance_client()
)
# Exception because role is client but server ips are not given
with self.assertRaises(ValueError):
self.mpc_service.start_instance(TEST_INSTANCE_ID)

def test_start_instance_skip_start_up(self):
def test_start_instance_skip_start_up(self) -> None:
# prep
self.mpc_service.instance_repository.read = MagicMock(
side_effect=self._read_side_effect_start
Expand All @@ -226,12 +238,14 @@ def test_start_instance_skip_start_up(self):
)
# asserts
self.mpc_service.onedocker_svc.wait_for_pending_containers.assert_not_called()
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`.
self.mpc_service.instance_repository.update.assert_called()
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `call_args_list`.
latest_update = self.mpc_service.instance_repository.update.call_args_list[-1]
updated_status = latest_update[0][0].status
self.assertEqual(updated_status, MPCInstanceStatus.CREATED)

def _read_side_effect_update(self, instance_id):
def _read_side_effect_update(self, instance_id) -> MPCInstance:
"""
mock MPCInstanceRepository.read for test_update,
with instance.containers is not None
Expand All @@ -251,7 +265,7 @@ def _read_side_effect_update(self, instance_id):
]
return mpc_instance

def test_update_instance(self):
def test_update_instance(self) -> None:
self.mpc_service.instance_repository.read = MagicMock(
side_effect=self._read_side_effect_update
)
Expand All @@ -266,9 +280,10 @@ def test_update_instance(self):
return_value=container_instances
)
self.mpc_service.update_instance(TEST_INSTANCE_ID)
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`.
self.mpc_service.instance_repository.update.assert_called()

def test_update_instance_from_unknown(self):
def test_update_instance_from_unknown(self) -> None:
# prep
mpc_instance = self._get_sample_mpcinstance()
mpc_instance.containers = [
Expand All @@ -292,18 +307,21 @@ def test_update_instance_from_unknown(self):
# act
self.mpc_service.update_instance(TEST_INSTANCE_ID)
# asserts
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`.
self.mpc_service.instance_repository.update.assert_called()
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `call_args_list`.
latest_update = self.mpc_service.instance_repository.update.call_args_list[-1]
updated_instance = latest_update[0][0]
self.assertEqual(updated_instance.status, MPCInstanceStatus.STARTED)
self.assertEqual(updated_instance.server_ips, ["10.0.1.130"])

def test_stop_instance(self):
def test_stop_instance(self) -> None:
self.mpc_service.instance_repository.read = MagicMock(
side_effect=self._read_side_effect_update
)
self.mpc_service.onedocker_svc.stop_containers = MagicMock(return_value=[None])
mpc_instance = self.mpc_service.stop_instance(TEST_INSTANCE_ID)
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.service.onedocker.OneDockerService.stop_containers` has no attribute `assert_called_with`.
self.mpc_service.onedocker_svc.stop_containers.assert_called_with(
[
"arn:aws:ecs:us-west-1:592513842793:task/57850450-7a81-43cc-8c73-2071c52e4a68"
Expand All @@ -312,11 +330,12 @@ def test_stop_instance(self):
expected_mpc_instance = self._read_side_effect_update(TEST_INSTANCE_ID)
expected_mpc_instance.status = MPCInstanceStatus.CANCELED
self.assertEqual(expected_mpc_instance, mpc_instance)
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called_with`.
self.mpc_service.instance_repository.update.assert_called_with(
expected_mpc_instance
)

def test_get_updated_instance(self):
def test_get_updated_instance(self) -> None:
# Arrange
queried_container = ContainerInstance(
TEST_INSTANCE_ID, # noqa
Expand Down

0 comments on commit c7121df

Please sign in to comment.