Skip to content

Commit

Permalink
Peek stream at the end of the event streaming to sync concurrent back…
Browse files Browse the repository at this point in the history
…buffer
  • Loading branch information
SergeyRyabinin committed Feb 18, 2025
1 parent 874da8a commit de758bc
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 62 deletions.
55 changes: 25 additions & 30 deletions src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,37 +411,32 @@ int CurlHttpClient::CurlProgressCallback(void *userdata, double, double, double,
CurlReadCallbackContext* context = reinterpret_cast<CurlReadCallbackContext*>(userdata);

const std::shared_ptr<Aws::IOStream>& ioStream = context->m_request->GetContentBody();
if (ioStream->eof())
{
curl_easy_pause(context->m_curlHandle, CURLPAUSE_CONT);
return 0;
}

if (context->m_client->m_perfMode == TransferLibPerformanceMode::LOW_LATENCY)
{
// forcing "underflow" on the IOStream with ConcurrentStreamBuf to move data from back buffer to put area
int peekVal = ioStream->peek();
AWS_UNREFERENCED_PARAM(peekVal);

// forcing curl to try to ReadBody again (~to poll body IOStream for HTTP2)
// This is a spin pause-unpause in case of no data provided by a customer callback
// But otherwise curl will slow down the transfer and start calling as at frequency of 1s
// see https://curl.se/mail/lib-2020-07/0046.html
// we should use multi handle or another HTTP client in the future to avoid this
if (!ioStream || ioStream->bad()) {
AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Input stream is bad!");
return 1; // libcurl abort the transfer and return CURLE_ABORTED_BY_CALLBACK.
}

const int peekVal = ioStream->peek();
if (ioStream->eof() && peekVal == std::char_traits<char>::eof()) {
// curl won't call ReadBody after the last ReadBody call returns 0
// however, this Progress method is still called few times for incoming data.
return 0;
}

if (context->m_client->m_perfMode == TransferLibPerformanceMode::LOW_LATENCY) {
AWS_UNREFERENCED_PARAM(peekVal);
// forcing curl to try to ReadBody again (~to poll body IOStream for HTTP2)
// This is a spin pause-unpause in case of no data provided by a customer callback
// But otherwise curl will slow down the transfer and start calling as at frequency of 1s
// see https://curl.se/mail/lib-2020-07/0046.html
// we should use multi handle or another HTTP client in the future to avoid this
curl_easy_pause(context->m_curlHandle, CURLPAUSE_CONT);
} else {
if (peekVal != std::char_traits<char>::eof()) {
curl_easy_pause(context->m_curlHandle, CURLPAUSE_CONT);
}
else
{
char output[1];
if (ioStream->readsome(output, 1) > 0)
{
ioStream->unget();
if (!ioStream->good())
{
AWS_LOGSTREAM_WARN(CURL_HTTP_CLIENT_TAG, "Input stream failed to perform unget().");
}
curl_easy_pause(context->m_curlHandle, CURLPAUSE_CONT);
}
} else {
curl_easy_pause(context->m_curlHandle, CURLPAUSE_SEND);
}
}

return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ add_project(aws-cpp-sdk-transcribestreaming-integ-tests
testing-resources
aws-cpp-sdk-core)

set(AWS_TRANSCRIBESTREAMING_INTEGRATION_TESTS_SRC "RunTests.cpp" "TranscribeTests.cpp")
set(AWS_TRANSCRIBESTREAMING_INTEGRATION_TESTS_SRC "RunTests.cpp" "TranscribeTests.cpp" "TranscribeErrorCaseTests.cpp")

add_definitions(-DRESOURCES_DIR="${CMAKE_CURRENT_SOURCE_DIR}/resources")

Expand Down
56 changes: 48 additions & 8 deletions tests/aws-cpp-sdk-transcribestreaming-integ-tests/RunTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,58 @@
* SPDX-License-Identifier: Apache-2.0.
*/

#include <gtest/gtest.h>
#include <aws/core/Aws.h>
#include <aws/core/internal/AWSHttpResourceClient.h>
#include <aws/core/platform/FileSystem.h>
#include <aws/testing/TestingEnvironment.h>
#include <gtest/gtest.h>

#include <fstream>

void DownloadTestStreams() {
Aws::SDKOptions options = {};
options.httpOptions.installSigPipeHandler = true;
Aws::InitAPI(options);
{
Aws::Client::ClientConfigurationInitValues cfgInit;
cfgInit.shouldDisableIMDS = true;
Aws::Internal::AWSHttpResourceClient httpclient(cfgInit);
const Aws::Vector<Aws::String> TEST_FILE_NAMES = {"transcribe-test-file.wav", "this_is_a_cpp_test_sample_8kHz_2162ms.wav",
"Kant_16kHz_17176ms.wav"};
for (const auto& toDownload : TEST_FILE_NAMES) {
Aws::String path = "aws-sdk-cpp-builds-sdks-team/test-resources/" + toDownload;
auto payload = httpclient.GetResource("https://s3.amazonaws.com/" /*endpoint*/, path.c_str(), "" /*authToken*/);
Aws::OFStream testFile(toDownload.c_str(), std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
testFile.write(payload.c_str(), payload.size());
testFile.close();
}
}
Aws::ShutdownAPI(options);
}

void Cleanup() {
Aws::SDKOptions options = {};
options.httpOptions.installSigPipeHandler = true;
Aws::InitAPI(options);
{
const Aws::Vector<Aws::String> TEST_FILE_NAMES = {"transcribe-test-file.wav", "this_is_a_cpp_test_sample_8kHz_2162ms.wav",
"Kant_16kHz_17176ms.wav"};
for (const auto& testFilename : TEST_FILE_NAMES) {
Aws::FileSystem::RemoveFileIfExists(testFilename.c_str());
}
Aws::ShutdownAPI(options);
}
}

int main(int argc, char** argv)
{
Aws::Testing::SetDefaultSigPipeHandler();
int main(int argc, char** argv) {
Aws::Testing::SetDefaultSigPipeHandler();

Aws::Testing::ParseArgs(argc, argv);
Aws::Testing::ParseArgs(argc, argv);

::testing::InitGoogleTest(&argc, argv);
int exitCode = RUN_ALL_TESTS();
::testing::InitGoogleTest(&argc, argv);
DownloadTestStreams();
int exitCode = RUN_ALL_TESTS();
Cleanup();

return exitCode;
return exitCode;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

// WinHTTP SDK client does not work with streaming
#if defined(HAVE_H2_CLIENT) && !defined(_WIN32)
#include <aws/core/Aws.h>
#include <aws/testing/AwsCppSdkGTestSuite.h>
#include <aws/testing/AwsTestHelpers.h>
#include <aws/transcribestreaming/TranscribeStreamingServiceClient.h>
#include <aws/transcribestreaming/model/StartStreamTranscriptionHandler.h>
#include <aws/transcribestreaming/model/StartStreamTranscriptionRequest.h>

#include <fstream>

using namespace Aws;
using namespace Aws::Utils;
using namespace Aws::TranscribeStreamingService;
using namespace Aws::TranscribeStreamingService::Model;

static const char ALLOC_TAG[] = "TranscribeStreamingErrorTests";
static const char TEST_FILE_NAME[] = "transcribe-test-file.wav";

#define AWS_ADD_FAILURE(MSG) \
ADD_FAILURE() << MSG; \
TestTraceFailure(MSG);

class TranscribeStreamingErrorTests : public Aws::Testing::AwsCppSdkGTestSuite {
public:
TranscribeStreamingErrorTests() = default;

protected:
Aws::String m_testTraces;
void TestTrace(Aws::String msg) { m_testTraces += Aws::Utils::DateTime::Now().ToGmtStringWithMs() + " " + std::move(msg) + "\n"; };
void TestTraceFailure(Aws::String msg) {
m_testTraces += "[FAILURE] " + Aws::Utils::DateTime::Now().ToGmtStringWithMs() + " " + std::move(msg) + "\n";
};
void SetUp() {}
void TearDown() {
if (::testing::Test::HasFailure()) {
std::cout << "Test traces: " << m_testTraces << "\n";
}
m_testTraces.erase();
}
};

TEST_F(TranscribeStreamingErrorTests, TranscribeAudioFile) {
Aws::Client::ClientConfigurationInitValues cfgInit;
cfgInit.shouldDisableIMDS = true;
Aws::Client::ClientConfiguration config(cfgInit);
config.httpLibPerfMode = Http::TransferLibPerformanceMode::REGULAR;
config.enableHttpClientTrace = true;
Aws::UniquePtr<TranscribeStreamingServiceClient> client = Aws::MakeUnique<TranscribeStreamingServiceClient>(ALLOC_TAG, config);

Aws::String transcribedResult;
StartStreamTranscriptionHandler handler;
handler.SetTranscriptEventCallback([&transcribedResult](const TranscriptEvent& ev) {
const auto& results = ev.GetTranscript().GetResults();
if (results.empty()) {
return;
}
const auto& last = results.back();
const auto& alternatives = last.GetAlternatives();
if (alternatives.empty()) {
return;
}
transcribedResult = alternatives.back().GetTranscript();
});

Aws::String operationRequestId;
handler.SetInitialResponseCallback([&](const StartStreamTranscriptionInitialResponse& initialResponse) {
operationRequestId = initialResponse.GetRequestId();
if (operationRequestId.empty()) {
AWS_ADD_FAILURE("InitialResponseCallback is called but received empty RequestId");
TestTrace(Aws::String("initialResponse was: ") + initialResponse.Jsonize().View().AsString());
}
TestTrace(Aws::String("InitialResponse aws RequestId: ") + operationRequestId);
TestTrace(Aws::String("InitialResponse transcribe SessionId: ") + initialResponse.GetSessionId());
});
handler.SetOnErrorCallback([&transcribedResult, this](const Aws::Client::AWSError<TranscribeStreamingServiceErrors>& errors) {
// we will receive an error because the request was abruptly shutdown (via stream.Close()).
// However, we cannot delay the call to closing the stream, because HTTP clients such as libcurl buffer the
// requests before sending them over the wire, so it will keep BUFFER_SIZE bytes in its memory if we don't
// signal the completion of the stream (by closing it).
// To discern between that case and a true error, we check if we have received any text back from the service.
if (transcribedResult.empty()) {
AWS_ADD_FAILURE(Aws::String("Received error: ") + errors.GetExceptionName() + ". with msg: " + errors.GetMessage())
}
});

StartStreamTranscriptionRequest request;
request.SetMediaSampleRateHertz(8000);
request.SetLanguageCode(LanguageCode::en_US);
request.SetMediaEncoding(MediaEncoding::pcm);
request.SetEventStreamHandler(handler);

auto OnStreamReady = [this](AudioStream& stream) {
TestTrace(Aws::String("OnStreamReady"));

for (size_t i = 0; i < 2; ++i) {
Aws::FStream file(TEST_FILE_NAME, std::ios_base::in | std::ios_base::binary);
ASSERT_TRUE(file);
char buf[2048];
while (file) {
file.read(buf, sizeof(buf));
Aws::Vector<unsigned char> bits{buf, buf + file.gcount()};
AudioEvent event(std::move(bits));
if (!stream) {
break;
}
TestTrace(Aws::String("Writing good event"));
if (!stream.WriteAudioEvent(event)) {
AWS_ADD_FAILURE("Failed to write an audio event");
break;
}
}
if (i == 0) {
TestTrace(Aws::String("Writing large bad event"));
Aws::Utils::Event::Message msg;
msg.InsertEventHeader(":message-type", Aws::String("event"));
msg.InsertEventHeader(":event-type", Aws::String("AudioEvent"));
msg.InsertEventHeader(":content-type", Aws::String("application/octet-stream"));
msg.WriteEventPayload([]() {
Aws::String payload;
for (size_t j = 0; j < 3000; ++j) {
payload += "very bad event ";
}
return payload;
}());
stream.WriteEvent(msg);
if (!stream) {
AWS_ADD_FAILURE("Failed to write a bad event");
break;
}
}
}

TestTrace(Aws::String("Sending final empty frame"));
stream.WriteAudioEvent({});
TestTrace(Aws::String("Flushing and closing the stream"));
stream.flush();
stream.Close();
};

Aws::Utils::Threading::Semaphore semaphore(0, 1);
auto OnResponseCallback = [&semaphore](const TranscribeStreamingServiceClient*, const StartStreamTranscriptionRequest&,
const StartStreamTranscriptionOutcome&,
const std::shared_ptr<const Aws::Client::AsyncCallerContext>&) { semaphore.ReleaseAll(); };

client->StartStreamTranscriptionAsync(request, OnStreamReady, OnResponseCallback, nullptr /*context*/);
semaphore.WaitOne();

EXPECT_FALSE(operationRequestId.empty()) << "Did not receive a request id for the StartStreamTranscription";
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,6 @@ class TranscribeStreamingTests : public Aws::Testing::AwsCppSdkGTestSuite
Aws::Client::ClientConfigurationInitValues cfgInit;
cfgInit.shouldDisableIMDS = true;
Aws::Client::ClientConfiguration config(cfgInit);

Aws::Internal::AWSHttpResourceClient httpclient(cfgInit);
const Aws::Vector<Aws::String> TEST_FILE_NAMES = {"transcribe-test-file.wav", "this_is_a_cpp_test_sample_8kHz_2162ms.wav", "Kant_16kHz_17176ms.wav"};
for(const auto& toDownload : TEST_FILE_NAMES)
{
Aws::String path = "aws-sdk-cpp-builds-sdks-team/test-resources/" + toDownload;
auto payload = httpclient.GetResource("https://s3.amazonaws.com/" /*endpoint*/,
path.c_str(), "" /*authToken*/);
Aws::OFStream testFile(toDownload.c_str(), std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
testFile.write(payload.c_str(), payload.size());
testFile.close();
}

config.enableHttpClientTrace = true;
#ifdef _WIN32
// TODO: remove this once we get H2 working with WinHttp client
Expand All @@ -65,10 +52,7 @@ class TranscribeStreamingTests : public Aws::Testing::AwsCppSdkGTestSuite
m_clientWithWrongEndpoint = Aws::MakeUnique<TranscribeStreamingServiceClient>(ALLOC_TAG, config);
}

~TranscribeStreamingTests()
{
Aws::FileSystem::RemoveFileIfExists(TEST_FILE_NAME);
}
~TranscribeStreamingTests() override = default;

Aws::UniquePtr<TranscribeStreamingServiceClient> m_clientWithWrongCreds;
Aws::UniquePtr<TranscribeStreamingServiceClient> m_clientWithWrongEndpoint;
Expand All @@ -83,10 +67,10 @@ class TranscribeStreamingTests : public Aws::Testing::AwsCppSdkGTestSuite
{
m_testTraces += "[FAILURE] " + Aws::Utils::DateTime::Now().ToGmtStringWithMs() + " " + std::move(msg) + "\n";
};
void SetUp()
void SetUp() override
{
}
void TearDown()
void TearDown() override
{
if (::testing::Test::HasFailure())
{
Expand Down Expand Up @@ -611,8 +595,7 @@ TEST_F(TranscribeStreamingTests, TranscribeStreamingCppSdkSample)
Aws::UniquePtr<TranscribeStreamingServiceClient> client = Aws::MakeUnique<TranscribeStreamingServiceClient>(ALLOC_TAG, config);

const Aws::Vector<Aws::String> EXPECTED_ALTERNATIVES = {"This is a C plus plus test sample", "This is a C++ test sample"};
for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 50)
{
for (size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 75) {
m_testTraces.clear();
TestTrace(Aws::String("### Starting TranscribeStreamingCppSdkSample with chunks of ") + Aws::Utils::StringUtils::to_string(chunkDuration) + " ms ##");
int64_t startedAt = Aws::Utils::DateTime::Now().Millis();
Expand Down Expand Up @@ -641,8 +624,7 @@ TEST_F(TranscribeStreamingTests, TranscribeStreamingKantSample)
static const char expected[] = "Categorical imperative: Act only according to that maxim whereby you can at the same time will that it should become a universal law. "
"Two things fill the mind with ever-increasing wonder and awe, the more often and the more intensely the mind of thought is drawn to them: "
"the starry heavens above me and the moral law within me.";
for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 50)
{
for (size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 75) {
m_testTraces.clear();
TestTrace(Aws::String("### Starting TranscribeStreamingKantSample with chunks of ") + Aws::Utils::StringUtils::to_string(chunkDuration) + " ms ##");
int64_t startedAt = Aws::Utils::DateTime::Now().Millis();
Expand Down

0 comments on commit de758bc

Please sign in to comment.