From 848c6c186331e6d4584c6edf647e100c50a5fcae Mon Sep 17 00:00:00 2001 From: Robert Bartel Date: Mon, 6 Nov 2023 10:40:22 -0500 Subject: [PATCH] Fix client package clients to use connection context. Updating classes to properly use async context management in transport client when appropriate. --- .../lib/client/dmod/client/request_clients.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/python/lib/client/dmod/client/request_clients.py b/python/lib/client/dmod/client/request_clients.py index 248e45223..773301cbc 100644 --- a/python/lib/client/dmod/client/request_clients.py +++ b/python/lib/client/dmod/client/request_clients.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from dmod.communication import (AuthClient, InvalidMessageResponse, ManagementAction, NGENRequest, NGENRequestResponse, NgenCalibrationRequest, NgenCalibrationResponse, TransportLayerClient) +from dmod.communication.client import ConnectionContextClient from dmod.communication.dataset_management_message import DatasetManagementMessage, DatasetManagementResponse, \ MaaSDatasetManagementMessage, MaaSDatasetManagementResponse, QueryType, DatasetQuery from dmod.communication.data_transmit_message import DataTransmitMessage, DataTransmitResponse @@ -25,8 +26,14 @@ def __init__(self, transport_client: TransportLayerClient, auth_client: AuthClie async def _submit_job_request(self, request) -> str: if await self._auth_client.apply_auth(request): - await self._transport_client.async_send(data=str(request)) - return await self._transport_client.async_recv() + # Some clients may be async context managers + if isinstance(self._transport_client, ConnectionContextClient): + async with self._transport_client as t_client: + await t_client.async_send(data=str(request)) + return await t_client.async_recv() + else: + await self._transport_client.async_send(data=str(request)) + return await self._transport_client.async_recv() else: msg = f"{self.__class__.__name__} could not use {self._auth_client.__class__.__name__} to authenticate " \ f"{request.__class__.__name__}" @@ -506,8 +513,15 @@ async def _process_request(self, request: DatasetManagementMessage) -> DatasetMa msg = f'{self.__class__.__name__} create_dataset could not apply auth to {request.__class__.__name__}' return response_type(success=False, reason=reason, message=msg) - await self._transport_client.async_send(data=str(request)) - response_data = await self._transport_client.async_recv() + # Some clients may be async context managers + if isinstance(self._transport_client, ConnectionContextClient): + async with self._transport_client as t_client: + await t_client.async_send(data=str(request)) + response_data = await t_client.async_recv() + else: + await self._transport_client.async_send(data=str(request)) + response_data = await self._transport_client.async_recv() + response_obj = response_type.factory_init_from_deserialized_json(json.loads(response_data)) if not isinstance(response_obj, response_type): msg = f"{self.__class__.__name__} could not deserialize {response_type.__name__} from raw response data" \