|
1 | 1 | import os
|
| 2 | +from collections.abc import AsyncIterable |
2 | 3 | from inspect import iscoroutinefunction
|
3 |
| -from io import BytesIO |
4 | 4 | from typing import Any
|
5 | 5 |
|
6 |
| -from smithy_core.aio.interfaces import ClientProtocol |
| 6 | +from smithy_core.aio.interfaces import AsyncByteStream, ClientProtocol |
| 7 | +from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob |
7 | 8 | from smithy_core.codecs import Codec
|
8 | 9 | from smithy_core.deserializers import DeserializeableShape
|
9 | 10 | from smithy_core.documents import TypeRegistry
|
@@ -109,35 +110,45 @@ async def deserialize_response[
|
109 | 110 | error_registry: TypeRegistry,
|
110 | 111 | context: TypedProperties,
|
111 | 112 | ) -> OperationOutput:
|
112 |
| - body = response.body |
113 |
| - |
114 |
| - # if body is not streaming and is async, we have to buffer it |
115 |
| - if not operation.output_stream_member and not is_streaming_blob(body): |
116 |
| - if ( |
117 |
| - read := getattr(body, "read", None) |
118 |
| - ) is not None and iscoroutinefunction(read): |
119 |
| - body = BytesIO(await read()) |
120 |
| - |
121 | 113 | if not self._is_success(operation, context, response):
|
122 | 114 | raise await self._create_error(
|
123 | 115 | operation=operation,
|
124 | 116 | request=request,
|
125 | 117 | response=response,
|
126 |
| - response_body=body, # type: ignore |
| 118 | + response_body=await self._buffer_async_body(response.body), |
127 | 119 | error_registry=error_registry,
|
128 | 120 | context=context,
|
129 | 121 | )
|
130 | 122 |
|
| 123 | + # if body is not streaming and is async, we have to buffer it |
| 124 | + body: SyncStreamingBlob | None = None |
| 125 | + if not operation.output_stream_member and not is_streaming_blob(body): |
| 126 | + body = await self._buffer_async_body(response.body) |
| 127 | + |
131 | 128 | # TODO(optimization): response binding cache like done in SJ
|
132 | 129 | deserializer = HTTPResponseDeserializer(
|
133 | 130 | payload_codec=self.payload_codec,
|
134 | 131 | http_trait=operation.schema.expect_trait(HTTPTrait),
|
135 | 132 | response=response,
|
136 |
| - body=body, # type: ignore |
| 133 | + body=body, |
137 | 134 | )
|
138 | 135 |
|
139 | 136 | return operation.output.deserialize(deserializer)
|
140 | 137 |
|
| 138 | + async def _buffer_async_body(self, stream: AsyncStreamingBlob) -> SyncStreamingBlob: |
| 139 | + match stream: |
| 140 | + case AsyncByteStream(): |
| 141 | + if not iscoroutinefunction(stream.read): |
| 142 | + return stream # type: ignore |
| 143 | + return await stream.read() |
| 144 | + case AsyncIterable(): |
| 145 | + full = b"" |
| 146 | + async for chunk in stream: |
| 147 | + full += chunk |
| 148 | + return full |
| 149 | + case _: |
| 150 | + return stream |
| 151 | + |
141 | 152 | def _is_success(
|
142 | 153 | self,
|
143 | 154 | operation: APIOperation[Any, Any],
|
|
0 commit comments