diff --git a/fbpcs/private_computation/service/mpc/mpc.py b/fbpcs/private_computation/service/mpc/mpc.py index 324c7bb17..23f96262e 100644 --- a/fbpcs/private_computation/service/mpc/mpc.py +++ b/fbpcs/private_computation/service/mpc/mpc.py @@ -111,6 +111,7 @@ def create_instance( num_workers: int, server_ips: Optional[List[str]] = None, game_args: Optional[List[Dict[str, Any]]] = None, + server_uris: Optional[List[str]] = None, ) -> MPCInstance: self.logger.info(f"Creating MPC instance: {instance_id}") @@ -123,7 +124,7 @@ def create_instance( [], MPCInstanceStatus.CREATED, game_args, - [], + server_uris, ) self.instance_repository.create(instance) diff --git a/fbpcs/private_computation/test/service/mpc/test_mpc.py b/fbpcs/private_computation/test/service/mpc/test_mpc.py index a746099e8..02c671e43 100644 --- a/fbpcs/private_computation/test/service/mpc/test_mpc.py +++ b/fbpcs/private_computation/test/service/mpc/test_mpc.py @@ -38,6 +38,10 @@ "concurrency": TEST_CONCURRENCY_ARGS, } ] +TEST_SERVER_URIS = [ + "node1.publisher.study1.pci.facebook.com", + "node2.publisher.study1.pci.facebook.com", +] class TestMPCService(IsolatedAsyncioTestCase): @@ -62,7 +66,7 @@ def setUp(self): ) @staticmethod - def _get_sample_mpcinstance(): + def _get_sample_mpcinstance() -> MPCInstance: return MPCInstance( TEST_INSTANCE_ID, TEST_GAME_NAME, @@ -72,11 +76,11 @@ def _get_sample_mpcinstance(): [], MPCInstanceStatus.CREATED, GAME_ARGS, - [], + TEST_SERVER_URIS, ) @staticmethod - def _get_sample_mpcinstance_with_game_args(): + def _get_sample_mpcinstance_with_game_args() -> MPCInstance: return MPCInstance( TEST_INSTANCE_ID, TEST_GAME_NAME, @@ -86,11 +90,11 @@ def _get_sample_mpcinstance_with_game_args(): [], MPCInstanceStatus.CREATED, GAME_ARGS, - [], + TEST_SERVER_URIS, ) @staticmethod - def _get_sample_mpcinstance_client(): + def _get_sample_mpcinstance_client() -> MPCInstance: return MPCInstance( TEST_INSTANCE_ID, TEST_GAME_NAME, @@ -100,10 +104,10 @@ def _get_sample_mpcinstance_client(): [], MPCInstanceStatus.CREATED, GAME_ARGS, - [], + TEST_SERVER_URIS, ) - async def test_spin_up_containers_onedocker_inconsistent_arguments(self): + async def test_spin_up_containers_onedocker_inconsistent_arguments(self) -> None: with self.assertRaisesRegex( ValueError, "The number of containers is not consistent with the number of game argument dictionary.", @@ -126,7 +130,7 @@ async def test_spin_up_containers_onedocker_inconsistent_arguments(self): ip_addresses=TEST_SERVER_IPS, ) - def test_create_instance_with_game_args(self): + def test_create_instance_with_game_args(self) -> None: self.mpc_service.create_instance( instance_id=TEST_INSTANCE_ID, game_name=TEST_GAME_NAME, @@ -134,14 +138,17 @@ def test_create_instance_with_game_args(self): num_workers=TEST_NUM_WORKERS, server_ips=TEST_SERVER_IPS, game_args=GAME_ARGS, + server_uris=TEST_SERVER_URIS, ) + # pyre-ignore [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `assert_called`. self.mpc_service.instance_repository.create.assert_called() self.assertEqual( self._get_sample_mpcinstance_with_game_args(), + # pyre-ignore [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `call_args`. self.mpc_service.instance_repository.create.call_args[0][0], ) - def test_create_instance(self): + def test_create_instance(self) -> None: self.mpc_service.create_instance( instance_id=TEST_INSTANCE_ID, game_name=TEST_GAME_NAME, @@ -149,22 +156,25 @@ def test_create_instance(self): num_workers=TEST_NUM_WORKERS, server_ips=TEST_SERVER_IPS, game_args=GAME_ARGS, + server_uris=TEST_SERVER_URIS, ) # check that instance with correct instance_id was created + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `assert_called`. self.mpc_service.instance_repository.create.assert_called() self.assertEqual( self._get_sample_mpcinstance(), + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.create` has no attribute `call_args`. self.mpc_service.instance_repository.create.call_args[0][0], ) - def _read_side_effect_start(self, instance_id: str): + def _read_side_effect_start(self, instance_id: str) -> MPCInstance: """mock MPCInstanceRepository.read for test_start""" if instance_id == TEST_INSTANCE_ID: return self._get_sample_mpcinstance() else: raise RuntimeError(f"{instance_id} does not exist") - def test_start_instance(self): + def test_start_instance(self) -> None: self.mpc_service.instance_repository.read = MagicMock( side_effect=self._read_side_effect_start ) @@ -187,12 +197,14 @@ def test_start_instance(self): ) # check that update is called with correct status self.mpc_service.start_instance(TEST_INSTANCE_ID) + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`. self.mpc_service.instance_repository.update.assert_called() + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `call_args_list`. latest_update = self.mpc_service.instance_repository.update.call_args_list[-1] updated_status = latest_update[0][0].status self.assertEqual(updated_status, MPCInstanceStatus.STARTED) - def test_start_instance_missing_ips(self): + def test_start_instance_missing_ips(self) -> None: self.mpc_service.instance_repository.read = MagicMock( return_value=self._get_sample_mpcinstance_client() ) @@ -200,7 +212,7 @@ def test_start_instance_missing_ips(self): with self.assertRaises(ValueError): self.mpc_service.start_instance(TEST_INSTANCE_ID) - def test_start_instance_skip_start_up(self): + def test_start_instance_skip_start_up(self) -> None: # prep self.mpc_service.instance_repository.read = MagicMock( side_effect=self._read_side_effect_start @@ -226,12 +238,14 @@ def test_start_instance_skip_start_up(self): ) # asserts self.mpc_service.onedocker_svc.wait_for_pending_containers.assert_not_called() + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`. self.mpc_service.instance_repository.update.assert_called() + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `call_args_list`. latest_update = self.mpc_service.instance_repository.update.call_args_list[-1] updated_status = latest_update[0][0].status self.assertEqual(updated_status, MPCInstanceStatus.CREATED) - def _read_side_effect_update(self, instance_id): + def _read_side_effect_update(self, instance_id) -> MPCInstance: """ mock MPCInstanceRepository.read for test_update, with instance.containers is not None @@ -251,7 +265,7 @@ def _read_side_effect_update(self, instance_id): ] return mpc_instance - def test_update_instance(self): + def test_update_instance(self) -> None: self.mpc_service.instance_repository.read = MagicMock( side_effect=self._read_side_effect_update ) @@ -266,9 +280,10 @@ def test_update_instance(self): return_value=container_instances ) self.mpc_service.update_instance(TEST_INSTANCE_ID) + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`. self.mpc_service.instance_repository.update.assert_called() - def test_update_instance_from_unknown(self): + def test_update_instance_from_unknown(self) -> None: # prep mpc_instance = self._get_sample_mpcinstance() mpc_instance.containers = [ @@ -292,18 +307,21 @@ def test_update_instance_from_unknown(self): # act self.mpc_service.update_instance(TEST_INSTANCE_ID) # asserts + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called`. self.mpc_service.instance_repository.update.assert_called() + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `call_args_list`. latest_update = self.mpc_service.instance_repository.update.call_args_list[-1] updated_instance = latest_update[0][0] self.assertEqual(updated_instance.status, MPCInstanceStatus.STARTED) self.assertEqual(updated_instance.server_ips, ["10.0.1.130"]) - def test_stop_instance(self): + def test_stop_instance(self) -> None: self.mpc_service.instance_repository.read = MagicMock( side_effect=self._read_side_effect_update ) self.mpc_service.onedocker_svc.stop_containers = MagicMock(return_value=[None]) mpc_instance = self.mpc_service.stop_instance(TEST_INSTANCE_ID) + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.service.onedocker.OneDockerService.stop_containers` has no attribute `assert_called_with`. self.mpc_service.onedocker_svc.stop_containers.assert_called_with( [ "arn:aws:ecs:us-west-1:592513842793:task/57850450-7a81-43cc-8c73-2071c52e4a68" @@ -312,11 +330,12 @@ def test_stop_instance(self): expected_mpc_instance = self._read_side_effect_update(TEST_INSTANCE_ID) expected_mpc_instance.status = MPCInstanceStatus.CANCELED self.assertEqual(expected_mpc_instance, mpc_instance) + # pyre-ignore Undefined attribute [16]: Callable `fbpcp.repository.mpc_instance.MPCInstanceRepository.update` has no attribute `assert_called_with`. self.mpc_service.instance_repository.update.assert_called_with( expected_mpc_instance ) - def test_get_updated_instance(self): + def test_get_updated_instance(self) -> None: # Arrange queried_container = ContainerInstance( TEST_INSTANCE_ID, # noqa