Skip to content

Commit 0ed3938

Browse files
Properly buffer async response bodies
1 parent 0e79346 commit 0ed3938

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

packages/smithy-http/src/smithy_http/aio/protocols.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
2+
from collections.abc import AsyncIterable
23
from inspect import iscoroutinefunction
3-
from io import BytesIO
44
from typing import Any
55

6-
from smithy_core.aio.interfaces import ClientProtocol
6+
from smithy_core.aio.interfaces import AsyncByteStream, ClientProtocol
7+
from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob
78
from smithy_core.codecs import Codec
89
from smithy_core.deserializers import DeserializeableShape
910
from smithy_core.documents import TypeRegistry
@@ -109,35 +110,45 @@ async def deserialize_response[
109110
error_registry: TypeRegistry,
110111
context: TypedProperties,
111112
) -> OperationOutput:
112-
body = response.body
113-
114-
# if body is not streaming and is async, we have to buffer it
115-
if not operation.output_stream_member and not is_streaming_blob(body):
116-
if (
117-
read := getattr(body, "read", None)
118-
) is not None and iscoroutinefunction(read):
119-
body = BytesIO(await read())
120-
121113
if not self._is_success(operation, context, response):
122114
raise await self._create_error(
123115
operation=operation,
124116
request=request,
125117
response=response,
126-
response_body=body, # type: ignore
118+
response_body=await self._buffer_async_body(response.body),
127119
error_registry=error_registry,
128120
context=context,
129121
)
130122

123+
# if body is not streaming and is async, we have to buffer it
124+
body: SyncStreamingBlob | None = None
125+
if not operation.output_stream_member and not is_streaming_blob(body):
126+
body = await self._buffer_async_body(response.body)
127+
131128
# TODO(optimization): response binding cache like done in SJ
132129
deserializer = HTTPResponseDeserializer(
133130
payload_codec=self.payload_codec,
134131
http_trait=operation.schema.expect_trait(HTTPTrait),
135132
response=response,
136-
body=body, # type: ignore
133+
body=body,
137134
)
138135

139136
return operation.output.deserialize(deserializer)
140137

138+
async def _buffer_async_body(self, stream: AsyncStreamingBlob) -> SyncStreamingBlob:
139+
match stream:
140+
case AsyncByteStream():
141+
if not iscoroutinefunction(stream.read):
142+
return stream # type: ignore
143+
return await stream.read()
144+
case AsyncIterable():
145+
full = b""
146+
async for chunk in stream:
147+
full += chunk
148+
return full
149+
case _:
150+
return stream
151+
141152
def _is_success(
142153
self,
143154
operation: APIOperation[Any, Any],

0 commit comments

Comments
 (0)