Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass server_uris in when creating MPCInstance #459

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fbpcp/service/mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def create_instance(
num_workers: int,
server_ips: Optional[List[str]] = None,
game_args: Optional[List[Dict[str, Any]]] = None,
server_uris: Optional[List[str]] = None,
) -> MPCInstance:
self.logger.info(f"Creating MPC instance: {instance_id}")

Expand All @@ -117,7 +118,7 @@ def create_instance(
[],
MPCInstanceStatus.CREATED,
game_args,
[],
server_uris,
)

self.instance_repository.create(instance)
Expand Down
55 changes: 37 additions & 18 deletions tests/service/test_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
"concurrency": TEST_CONCURRENCY_ARGS,
}
]
TEST_SERVER_URIS = [
"node1.publisher.study1.pci.facebook.com",
"node2.publisher.study1.pci.facebook.com",
]


class TestMPCService(IsolatedAsyncioTestCase):
Expand All @@ -54,7 +58,7 @@ def setUp(self):
)

@staticmethod
def _get_sample_mpcinstance():
def _get_sample_mpcinstance() -> MPCInstance:
return MPCInstance(
TEST_INSTANCE_ID,
TEST_GAME_NAME,
Expand All @@ -64,11 +68,11 @@ def _get_sample_mpcinstance():
[],
MPCInstanceStatus.CREATED,
GAME_ARGS,
[],
TEST_SERVER_URIS,
)

@staticmethod
def _get_sample_mpcinstance_with_game_args():
def _get_sample_mpcinstance_with_game_args() -> MPCInstance:
return MPCInstance(
TEST_INSTANCE_ID,
TEST_GAME_NAME,
Expand All @@ -78,11 +82,11 @@ def _get_sample_mpcinstance_with_game_args():
[],
MPCInstanceStatus.CREATED,
GAME_ARGS,
[],
TEST_SERVER_URIS,
)

@staticmethod
def _get_sample_mpcinstance_client():
def _get_sample_mpcinstance_client() -> MPCInstance:
return MPCInstance(
TEST_INSTANCE_ID,
TEST_GAME_NAME,
Expand All @@ -92,10 +96,10 @@ def _get_sample_mpcinstance_client():
[],
MPCInstanceStatus.CREATED,
GAME_ARGS,
[],
TEST_SERVER_URIS,
)

async def test_spin_up_containers_onedocker_inconsistent_arguments(self):
async def test_spin_up_containers_onedocker_inconsistent_arguments(self) -> None:
with self.assertRaisesRegex(
ValueError,
"The number of containers is not consistent with the number of game argument dictionary.",
Expand All @@ -118,45 +122,51 @@ async def test_spin_up_containers_onedocker_inconsistent_arguments(self):
ip_addresses=TEST_SERVER_IPS,
)

def test_create_instance_with_game_args(self):
def test_create_instance_with_game_args(self) -> None:
self.mpc_service.create_instance(
instance_id=TEST_INSTANCE_ID,
game_name=TEST_GAME_NAME,
mpc_party=TEST_MPC_ROLE,
num_workers=TEST_NUM_WORKERS,
server_ips=TEST_SERVER_IPS,
game_args=GAME_ARGS,
server_uris=TEST_SERVER_URIS,
)
# pyre-ignore [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `assert_called`.
self.mpc_service.instance_repository.create.assert_called()
self.assertEqual(
self._get_sample_mpcinstance_with_game_args(),
# pyre-ignore [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `call_args`.
self.mpc_service.instance_repository.create.call_args[0][0],
)

def test_create_instance(self):
def test_create_instance(self) -> None:
self.mpc_service.create_instance(
instance_id=TEST_INSTANCE_ID,
game_name=TEST_GAME_NAME,
mpc_party=TEST_MPC_ROLE,
num_workers=TEST_NUM_WORKERS,
server_ips=TEST_SERVER_IPS,
game_args=GAME_ARGS,
server_uris=TEST_SERVER_URIS,
)
# check that instance with correct instance_id was created
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `assert_called`.
self.mpc_service.instance_repository.create.assert_called()
self.assertEqual(
self._get_sample_mpcinstance(),
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `call_args`.
self.mpc_service.instance_repository.create.call_args[0][0],
)

def _read_side_effect_start(self, instance_id: str):
def _read_side_effect_start(self, instance_id: str) -> MPCInstance:
"""mock MPCInstanceRepository.read for test_start"""
if instance_id == TEST_INSTANCE_ID:
return self._get_sample_mpcinstance()
else:
raise RuntimeError(f"{instance_id} does not exist")

def test_start_instance(self):
def test_start_instance(self) -> None:
self.mpc_service.instance_repository.read = MagicMock(
side_effect=self._read_side_effect_start
)
Expand All @@ -179,20 +189,22 @@ def test_start_instance(self):
)
# check that update is called with correct status
self.mpc_service.start_instance(TEST_INSTANCE_ID)
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`.
self.mpc_service.instance_repository.update.assert_called()
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `call_args_list`.
latest_update = self.mpc_service.instance_repository.update.call_args_list[-1]
updated_status = latest_update[0][0].status
self.assertEqual(updated_status, MPCInstanceStatus.STARTED)

def test_start_instance_missing_ips(self):
def test_start_instance_missing_ips(self) -> None:
self.mpc_service.instance_repository.read = MagicMock(
return_value=self._get_sample_mpcinstance_client()
)
# Exception because role is client but server ips are not given
with self.assertRaises(ValueError):
self.mpc_service.start_instance(TEST_INSTANCE_ID)

def test_start_instance_skip_start_up(self):
def test_start_instance_skip_start_up(self) -> None:
# prep
self.mpc_service.instance_repository.read = MagicMock(
side_effect=self._read_side_effect_start
Expand All @@ -218,12 +230,14 @@ def test_start_instance_skip_start_up(self):
)
# asserts
self.mpc_service.onedocker_svc.wait_for_pending_containers.assert_not_called()
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`.
self.mpc_service.instance_repository.update.assert_called()
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `call_args_list`.
latest_update = self.mpc_service.instance_repository.update.call_args_list[-1]
updated_status = latest_update[0][0].status
self.assertEqual(updated_status, MPCInstanceStatus.CREATED)

def _read_side_effect_update(self, instance_id):
def _read_side_effect_update(self, instance_id) -> MPCInstance:
"""
mock MPCInstanceRepository.read for test_update,
with instance.containers is not None
Expand All @@ -243,7 +257,7 @@ def _read_side_effect_update(self, instance_id):
]
return mpc_instance

def test_update_instance(self):
def test_update_instance(self) -> None:
self.mpc_service.instance_repository.read = MagicMock(
side_effect=self._read_side_effect_update
)
Expand All @@ -258,9 +272,10 @@ def test_update_instance(self):
return_value=container_instances
)
self.mpc_service.update_instance(TEST_INSTANCE_ID)
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`.
self.mpc_service.instance_repository.update.assert_called()

def test_update_instance_from_unknown(self):
def test_update_instance_from_unknown(self) -> None:
# prep
mpc_instance = self._get_sample_mpcinstance()
mpc_instance.containers = [
Expand All @@ -284,18 +299,21 @@ def test_update_instance_from_unknown(self):
# act
self.mpc_service.update_instance(TEST_INSTANCE_ID)
# asserts
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`.
self.mpc_service.instance_repository.update.assert_called()
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `call_args_list`.
latest_update = self.mpc_service.instance_repository.update.call_args_list[-1]
updated_instance = latest_update[0][0]
self.assertEqual(updated_instance.status, MPCInstanceStatus.STARTED)
self.assertEqual(updated_instance.server_ips, ["10.0.1.130"])

def test_stop_instance(self):
def test_stop_instance(self) -> None:
self.mpc_service.instance_repository.read = MagicMock(
side_effect=self._read_side_effect_update
)
self.mpc_service.onedocker_svc.stop_containers = MagicMock(return_value=[None])
mpc_instance = self.mpc_service.stop_instance(TEST_INSTANCE_ID)
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.service.onedocker.OneDockerService.stop_containers` has no attribute `assert_called_with`.
self.mpc_service.onedocker_svc.stop_containers.assert_called_with(
[
"arn:aws:ecs:us-west-1:592513842793:task/57850450-7a81-43cc-8c73-2071c52e4a68"
Expand All @@ -304,11 +322,12 @@ def test_stop_instance(self):
expected_mpc_instance = self._read_side_effect_update(TEST_INSTANCE_ID)
expected_mpc_instance.status = MPCInstanceStatus.CANCELED
self.assertEqual(expected_mpc_instance, mpc_instance)
# pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called_with`.
self.mpc_service.instance_repository.update.assert_called_with(
expected_mpc_instance
)

def test_get_updated_instance(self):
def test_get_updated_instance(self) -> None:
# Arrange
queried_container = ContainerInstance(
TEST_INSTANCE_ID, # noqa
Expand Down