diff --git a/reportportal_client/aio/client.py b/reportportal_client/aio/client.py index 88e36f7..eaf7ac4 100644 --- a/reportportal_client/aio/client.py +++ b/reportportal_client/aio/client.py @@ -1351,6 +1351,9 @@ def __init_loop(self, loop: Optional[asyncio.AbstractEventLoop] = None): daemon=True) self._thread.start() + async def __return_value(self, value): + return value + def __init__( self, endpoint: str, @@ -1358,6 +1361,7 @@ def __init__( *, task_timeout: float = DEFAULT_TASK_TIMEOUT, shutdown_timeout: float = DEFAULT_SHUTDOWN_TIMEOUT, + launch_uuid: Optional[Union[str, Task[str]]] = None, task_list: Optional[BackgroundTaskList[Task[_T]]] = None, task_mutex: Optional[threading.RLock] = None, loop: Optional[asyncio.AbstractEventLoop] = None, @@ -1399,11 +1403,15 @@ def __init__( :param loop: Event Loop which is used to process Tasks. The Client creates own one if this argument is None. """ - super().__init__(endpoint, project, **kwargs) self.task_timeout = task_timeout self.shutdown_timeout = shutdown_timeout self.__init_task_list(task_list, task_mutex) self.__init_loop(loop) + if type(launch_uuid) == str: + super().__init__(endpoint, project, + launch_uuid=self.create_task(self.__return_value(launch_uuid)), **kwargs) + else: + super().__init__(endpoint, project, launch_uuid=launch_uuid, **kwargs) def create_task(self, coro: Coroutine[Any, Any, _T]) -> Optional[Task[_T]]: """Create a Task from given Coroutine. @@ -1518,6 +1526,9 @@ def __init_loop(self, loop: Optional[asyncio.AbstractEventLoop] = None): self._loop = asyncio.new_event_loop() self._loop.set_task_factory(BatchedTaskFactory()) + async def __return_value(self, value): + return value + def __init__( self, endpoint: str, @@ -1525,6 +1536,7 @@ def __init__( *, task_timeout: float = DEFAULT_TASK_TIMEOUT, shutdown_timeout: float = DEFAULT_SHUTDOWN_TIMEOUT, + launch_uuid: Optional[Union[str, Task[str]]] = None, task_list: Optional[TriggerTaskBatcher] = None, task_mutex: Optional[threading.RLock] = None, loop: Optional[asyncio.AbstractEventLoop] = None, @@ -1570,7 +1582,6 @@ def __init__( :param trigger_num: Number of tasks which triggers Task batch execution. :param trigger_interval: Time limit which triggers Task batch execution. """ - super().__init__(endpoint, project, **kwargs) self.task_timeout = task_timeout self.shutdown_timeout = shutdown_timeout self.trigger_num = trigger_num @@ -1578,6 +1589,11 @@ def __init__( self.__init_task_list(task_list, task_mutex) self.__last_run_time = datetime.time() self.__init_loop(loop) + if type(launch_uuid) == str: + super().__init__(endpoint, project, + launch_uuid=self.create_task(self.__return_value(launch_uuid)), **kwargs) + else: + super().__init__(endpoint, project, launch_uuid=launch_uuid, **kwargs) def create_task(self, coro: Coroutine[Any, Any, _T]) -> Optional[Task[_T]]: """Create a Task from given Coroutine. diff --git a/tests/aio/test_batched_client.py b/tests/aio/test_batched_client.py index a9502f0..9962a11 100644 --- a/tests/aio/test_batched_client.py +++ b/tests/aio/test_batched_client.py @@ -11,8 +11,14 @@ # See the License for the specific language governing permissions and # limitations under the License import pickle +import sys +from unittest import mock + +# noinspection PyPackageRequirements +import pytest from reportportal_client.aio import BatchedRPClient +from reportportal_client.helpers import timestamp def test_batched_rp_client_pickling(): @@ -60,7 +66,7 @@ def test_clone(): ) assert ( cloned.client.api_key == kwargs['api_key'] - and cloned.launch_uuid == kwargs['launch_uuid'] + and cloned.launch_uuid.blocking_result() == kwargs['launch_uuid'] and cloned.log_batch_size == kwargs['log_batch_size'] and cloned.log_batch_payload_limit == kwargs['log_batch_payload_limit'] and cloned.task_timeout == kwargs['task_timeout'] @@ -70,3 +76,64 @@ def test_clone(): ) assert cloned._item_stack.qsize() == 1 \ and async_client.current_item() == cloned.current_item() + + +@pytest.mark.skipif(sys.version_info < (3, 8), + reason='the test requires AsyncMock which was introduced in Python 3.8') +@pytest.mark.parametrize( + 'launch_uuid, method, params', + [ + ('test_launch_uuid', 'start_test_item', ['Test Item', timestamp(), 'STEP']), + ('test_launch_uuid', 'finish_test_item', ['test_item_id', timestamp()]), + ('test_launch_uuid', 'get_launch_info', []), + ('test_launch_uuid', 'get_launch_ui_id', []), + ('test_launch_uuid', 'get_launch_ui_url', []), + ('test_launch_uuid', 'log', [timestamp(), 'Test message']), + (None, 'start_test_item', ['Test Item', timestamp(), 'STEP']), + (None, 'finish_test_item', ['test_item_id', timestamp()]), + (None, 'get_launch_info', []), + (None, 'get_launch_ui_id', []), + (None, 'get_launch_ui_url', []), + (None, 'log', [timestamp(), 'Test message']), + ] +) +def test_launch_uuid_usage(launch_uuid, method, params): + started_launch_uuid = 'new_test_launch_uuid' + aio_client = mock.AsyncMock() + aio_client.start_launch.return_value = started_launch_uuid + client = BatchedRPClient('http://endpoint', 'project', api_key='api_key', + client=aio_client, launch_uuid=launch_uuid, log_batch_size=1) + actual_launch_uuid = (client.start_launch('Test Launch', timestamp())).blocking_result() + getattr(client, method)(*params).blocking_result() + finish_launch_message = (client.finish_launch(timestamp())).blocking_result() + + if launch_uuid is None: + aio_client.start_launch.assert_called_once() + assert actual_launch_uuid == started_launch_uuid + assert client.launch_uuid.blocking_result() == started_launch_uuid + aio_client.finish_launch.assert_called_once() + assert finish_launch_message + else: + aio_client.start_launch.assert_not_called() + assert actual_launch_uuid == launch_uuid + assert client.launch_uuid.blocking_result() == launch_uuid + aio_client.finish_launch.assert_not_called() + assert finish_launch_message == '' + assert client.launch_uuid.blocking_result() == actual_launch_uuid + + if method == 'log': + assert len(getattr(aio_client, 'log_batch').call_args_list) == 2 + args, kwargs = getattr(aio_client, 'log_batch').call_args_list[0] + batch = args[0] + assert isinstance(batch, list) + assert len(batch) == 1 + log = batch[0] + assert log.launch_uuid.blocking_result() == actual_launch_uuid + assert log.time == params[0] + assert log.message == params[1] + else: + getattr(aio_client, method).assert_called_once() + args, kwargs = getattr(aio_client, method).call_args_list[0] + assert args[0].blocking_result() == actual_launch_uuid + for i, param in enumerate(params): + assert args[i + 1] == param diff --git a/tests/aio/test_threaded_client.py b/tests/aio/test_threaded_client.py index caf2694..cc0334c 100644 --- a/tests/aio/test_threaded_client.py +++ b/tests/aio/test_threaded_client.py @@ -59,7 +59,7 @@ def test_clone(): ) assert ( cloned.client.api_key == kwargs['api_key'] - and cloned.launch_uuid == kwargs['launch_uuid'] + and cloned.launch_uuid.blocking_result() == kwargs['launch_uuid'] and cloned.log_batch_size == kwargs['log_batch_size'] and cloned.log_batch_payload_limit == kwargs['log_batch_payload_limit'] and cloned.task_timeout == kwargs['task_timeout']