diff --git a/include/aws/http/private/h1_stream.h b/include/aws/http/private/h1_stream.h index df1446ec..8b210c8b 100644 --- a/include/aws/http/private/h1_stream.h +++ b/include/aws/http/private/h1_stream.h @@ -117,6 +117,7 @@ struct aws_h1_stream *aws_h1_stream_new_request( struct aws_h1_stream *aws_h1_stream_new_request_handler(const struct aws_http_request_handler_options *options); int aws_h1_stream_activate(struct aws_http_stream *stream); +void aws_h1_stream_cancel(struct aws_http_stream *stream, int error_code); int aws_h1_stream_send_response(struct aws_h1_stream *stream, struct aws_http_message *response); diff --git a/include/aws/http/private/request_response_impl.h b/include/aws/http/private/request_response_impl.h index 0a620b7f..acc5d9dd 100644 --- a/include/aws/http/private/request_response_impl.h +++ b/include/aws/http/private/request_response_impl.h @@ -17,6 +17,7 @@ struct aws_http_stream_vtable { void (*destroy)(struct aws_http_stream *stream); void (*update_window)(struct aws_http_stream *stream, size_t increment_size); int (*activate)(struct aws_http_stream *stream); + void (*cancel)(struct aws_http_stream *stream, int error_code); int (*http1_write_chunk)(struct aws_http_stream *http1_stream, const struct aws_http1_chunk_options *options); int (*http1_add_trailer)(struct aws_http_stream *http1_stream, const struct aws_http_headers *trailing_headers); diff --git a/include/aws/http/request_response.h b/include/aws/http/request_response.h index 878dbaf3..3f25fdd2 100644 --- a/include/aws/http/request_response.h +++ b/include/aws/http/request_response.h @@ -1109,6 +1109,18 @@ void aws_http_stream_update_window(struct aws_http_stream *stream, size_t increm AWS_HTTP_API uint32_t aws_http_stream_get_id(const struct aws_http_stream *stream); +/** + * Cancel the stream in flight. + * For HTTP/1.1 streams, it's equivalent to closing the connection. + * For HTTP/2 streams, it's equivalent to calling reset on the stream with `AWS_HTTP2_ERR_CANCEL`. + * + * the stream will complete with the error code provided, unless the stream is + * already completing for other reasons, or the stream is not activated, + * in which case this call will have no impact. + */ +AWS_HTTP_API +void aws_http_stream_cancel(struct aws_http_stream *stream, int error_code); + /** * Reset the HTTP/2 stream (HTTP/2 only). * Note that if the stream closes before this async call is fully processed, the RST_STREAM frame will not be sent. diff --git a/source/h1_connection.c b/source/h1_connection.c index d04e844a..903cf038 100644 --- a/source/h1_connection.c +++ b/source/h1_connection.c @@ -388,6 +388,34 @@ int aws_h1_stream_activate(struct aws_http_stream *stream) { return AWS_OP_SUCCESS; } +void aws_h1_stream_cancel(struct aws_http_stream *stream, int error_code) { + struct aws_h1_stream *h1_stream = AWS_CONTAINER_OF(stream, struct aws_h1_stream, base); + struct aws_http_connection *base_connection = stream->owning_connection; + struct aws_h1_connection *connection = AWS_CONTAINER_OF(base_connection, struct aws_h1_connection, base); + + { /* BEGIN CRITICAL SECTION */ + aws_h1_connection_lock_synced_data(connection); + if (h1_stream->synced_data.api_state != AWS_H1_STREAM_API_STATE_ACTIVE || + connection->synced_data.is_open == false) { + /* Not active, nothing to cancel. */ + aws_h1_connection_unlock_synced_data(connection); + AWS_LOGF_DEBUG(AWS_LS_HTTP_STREAM, "id=%p: Stream not active, nothing to cancel.", (void *)stream); + return; + } + + aws_h1_connection_unlock_synced_data(connection); + } /* END CRITICAL SECTION */ + AWS_LOGF_INFO( + AWS_LS_HTTP_CONNECTION, + "id=%p: Connection shutting down due to stream=%p cancelled with error code %d (%s).", + (void *)&connection->base, + (void *)stream, + error_code, + aws_error_name(error_code)); + + s_stop(connection, false /*stop_reading*/, false /*stop_writing*/, true /*schedule_shutdown*/, error_code); +} + struct aws_http_stream *s_make_request( struct aws_http_connection *client_connection, const struct aws_http_make_request_options *options) { diff --git a/source/h1_stream.c b/source/h1_stream.c index 0066ff69..ef8f086b 100644 --- a/source/h1_stream.c +++ b/source/h1_stream.c @@ -329,6 +329,7 @@ static const struct aws_http_stream_vtable s_stream_vtable = { .destroy = s_stream_destroy, .update_window = s_stream_update_window, .activate = aws_h1_stream_activate, + .cancel = aws_h1_stream_cancel, .http1_write_chunk = s_stream_write_chunk, .http1_add_trailer = s_stream_add_trailer, .http2_reset_stream = NULL, diff --git a/source/h2_stream.c b/source/h2_stream.c index 3ce23752..6d8336bf 100644 --- a/source/h2_stream.c +++ b/source/h2_stream.c @@ -27,12 +27,17 @@ static int s_stream_write_data( static void s_stream_cross_thread_work_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); static struct aws_h2err s_send_rst_and_close_stream(struct aws_h2_stream *stream, struct aws_h2err stream_error); -static int s_stream_reset_stream_internal(struct aws_http_stream *stream_base, struct aws_h2err stream_error); +static int s_stream_reset_stream_internal( + struct aws_http_stream *stream_base, + struct aws_h2err stream_error, + bool cancelling); +static void s_stream_cancel(struct aws_http_stream *stream, int error_code); struct aws_http_stream_vtable s_h2_stream_vtable = { .destroy = s_stream_destroy, .update_window = s_stream_update_window, .activate = aws_h2_stream_activate, + .cancel = s_stream_cancel, .http1_write_chunk = NULL, .http2_reset_stream = s_stream_reset_stream, .http2_get_received_error_code = s_stream_get_received_error_code, @@ -526,12 +531,16 @@ static void s_stream_update_window(struct aws_http_stream *stream_base, size_t i .h2_code = AWS_HTTP2_ERR_INTERNAL_ERROR, }; /* Only when stream is not initialized reset will fail. So, we can assert it to be succeed. */ - AWS_FATAL_ASSERT(s_stream_reset_stream_internal(stream_base, stream_error) == AWS_OP_SUCCESS); + AWS_FATAL_ASSERT( + s_stream_reset_stream_internal(stream_base, stream_error, false /*cancelling*/) == AWS_OP_SUCCESS); } return; } -static int s_stream_reset_stream_internal(struct aws_http_stream *stream_base, struct aws_h2err stream_error) { +static int s_stream_reset_stream_internal( + struct aws_http_stream *stream_base, + struct aws_h2err stream_error, + bool cancelling) { struct aws_h2_stream *stream = AWS_CONTAINER_OF(stream_base, struct aws_h2_stream, base); struct aws_h2_connection *connection = s_get_h2_connection(stream); @@ -553,21 +562,25 @@ static int s_stream_reset_stream_internal(struct aws_http_stream *stream_base, s } /* END CRITICAL SECTION */ if (stream_is_init) { + if (cancelling) { + /* Not an error if we are just cancelling. */ + AWS_LOGF_DEBUG(AWS_LS_HTTP_STREAM, "id=%p: Stream not in process, nothing to cancel.", (void *)stream); + return AWS_OP_SUCCESS; + } AWS_H2_STREAM_LOG( ERROR, stream, "Reset stream failed. Stream is in initialized state, please activate the stream first."); return aws_raise_error(AWS_ERROR_INVALID_STATE); } + if (reset_called) { + AWS_H2_STREAM_LOG(DEBUG, stream, "Reset stream ignored. Reset stream has been called already."); + } + if (cross_thread_work_should_schedule) { AWS_H2_STREAM_LOG(TRACE, stream, "Scheduling stream cross-thread work task"); /* increment the refcount of stream to keep it alive until the task runs */ aws_atomic_fetch_add(&stream->base.refcount, 1); aws_channel_schedule_task_now(connection->base.channel_slot->channel, &stream->cross_thread_work_task); - return AWS_OP_SUCCESS; } - if (reset_called) { - AWS_H2_STREAM_LOG(DEBUG, stream, "Reset stream ignored. Reset stream has been called already."); - } - return AWS_OP_SUCCESS; } @@ -583,7 +596,16 @@ static int s_stream_reset_stream(struct aws_http_stream *stream_base, uint32_t h (void *)stream_base, aws_http2_error_code_to_str(http2_error), http2_error); - return s_stream_reset_stream_internal(stream_base, stream_error); + return s_stream_reset_stream_internal(stream_base, stream_error, false /*cancelling*/); +} + +void s_stream_cancel(struct aws_http_stream *stream_base, int error_code) { + struct aws_h2err stream_error = { + .aws_code = error_code, + .h2_code = AWS_HTTP2_ERR_CANCEL, + }; + s_stream_reset_stream_internal(stream_base, stream_error, true /*cancelling*/); + return; } static int s_stream_get_received_error_code(struct aws_http_stream *stream_base, uint32_t *out_http2_error) { diff --git a/source/request_response.c b/source/request_response.c index f76115f2..dbd5214e 100644 --- a/source/request_response.c +++ b/source/request_response.c @@ -1201,6 +1201,10 @@ uint32_t aws_http_stream_get_id(const struct aws_http_stream *stream) { return stream->id; } +void aws_http_stream_cancel(struct aws_http_stream *stream, int error_code) { + stream->vtable->cancel(stream, error_code); +} + int aws_http2_stream_reset(struct aws_http_stream *http2_stream, uint32_t http2_error) { AWS_PRECONDITION(http2_stream); AWS_PRECONDITION(http2_stream->vtable); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7e61445d..06dc389f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -143,6 +143,7 @@ add_test_case(h1_client_switching_protocols_fails_pending_requests) add_test_case(h1_client_switching_protocols_fails_subsequent_requests) add_test_case(h1_client_switching_protocols_requires_downstream_handler) add_test_case(h1_client_connection_close_before_request_finishes) +add_test_case(h1_client_stream_cancel) add_test_case(h1_client_response_close_connection_before_request_finishes) add_test_case(h1_client_response_first_byte_timeout_connection) add_test_case(h1_client_response_first_byte_timeout_request_override) @@ -475,6 +476,7 @@ add_test_case(h2_client_conn_failed_initial_settings_completed_not_invoked) add_test_case(h2_client_stream_reset_stream) add_test_case(h2_client_stream_reset_ignored_stream_closed) add_test_case(h2_client_stream_reset_failed_before_activate_called) +add_test_case(h2_client_stream_cancel_stream) add_test_case(h2_client_stream_keeps_alive_for_cross_thread_task) add_test_case(h2_client_stream_get_received_reset_error_code) add_test_case(h2_client_stream_get_sent_reset_error_code) diff --git a/tests/test_h1_client.c b/tests/test_h1_client.c index de7751a8..19cf1f97 100644 --- a/tests/test_h1_client.c +++ b/tests/test_h1_client.c @@ -4236,6 +4236,62 @@ H1_CLIENT_TEST_CASE(h1_client_connection_close_before_request_finishes) { return AWS_OP_SUCCESS; } +H1_CLIENT_TEST_CASE(h1_client_stream_cancel) { + (void)ctx; + struct tester tester; + ASSERT_SUCCESS(s_tester_init(&tester, allocator)); + + /* set up request whose body won't send immediately */ + struct slow_body_sender body_sender; + AWS_ZERO_STRUCT(body_sender); + s_slow_body_sender_init(&body_sender); + struct aws_input_stream *body_stream = &body_sender.base; + + struct aws_http_header headers[] = { + { + .name = aws_byte_cursor_from_c_str("Content-Length"), + .value = aws_byte_cursor_from_c_str("16"), + }, + }; + + struct aws_http_message *request = aws_http_message_new_request(allocator); + ASSERT_NOT_NULL(request); + ASSERT_SUCCESS(aws_http_message_set_request_method(request, aws_byte_cursor_from_c_str("PUT"))); + ASSERT_SUCCESS(aws_http_message_set_request_path(request, aws_byte_cursor_from_c_str("/plan.txt"))); + ASSERT_SUCCESS(aws_http_message_add_header_array(request, headers, AWS_ARRAY_SIZE(headers))); + aws_http_message_set_body_stream(request, body_stream); + + struct client_stream_tester stream_tester; + ASSERT_SUCCESS(s_stream_tester_init(&stream_tester, &tester, request)); + + /* send head of request */ + testing_channel_run_currently_queued_tasks(&tester.testing_channel); + + /* Ensure the request can be destroyed after request is sent */ + aws_http_message_destroy(request); + aws_input_stream_release(body_stream); + + /* Something absurd */ + aws_http_stream_cancel(stream_tester.stream, AWS_ERROR_COND_VARIABLE_ERROR_UNKNOWN); + /* The second call will take not action */ + aws_http_stream_cancel(stream_tester.stream, AWS_ERROR_SUCCESS); + /* Wait for channel to finish shutdown */ + testing_channel_drain_queued_tasks(&tester.testing_channel); + /* check result, should not receive any body */ + const char *expected = "PUT /plan.txt HTTP/1.1\r\n" + "Content-Length: 16\r\n" + "\r\n"; + ASSERT_SUCCESS(testing_channel_check_written_messages_str(&tester.testing_channel, allocator, expected)); + + ASSERT_TRUE(stream_tester.complete); + ASSERT_INT_EQUALS(AWS_ERROR_COND_VARIABLE_ERROR_UNKNOWN, stream_tester.on_complete_error_code); + + /* clean up */ + client_stream_tester_clean_up(&stream_tester); + ASSERT_SUCCESS(s_tester_clean_up(&tester)); + return AWS_OP_SUCCESS; +} + /* When response has `connection: close` any further request body should not be sent. */ H1_CLIENT_TEST_CASE(h1_client_response_close_connection_before_request_finishes) { (void)ctx; diff --git a/tests/test_h2_client.c b/tests/test_h2_client.c index 74393188..093e89f9 100644 --- a/tests/test_h2_client.c +++ b/tests/test_h2_client.c @@ -4508,6 +4508,49 @@ TEST_CASE(h2_client_stream_reset_failed_before_activate_called) { return s_tester_clean_up(); } +TEST_CASE(h2_client_stream_cancel_stream) { + ASSERT_SUCCESS(s_tester_init(allocator, ctx)); + /* get connection preface and acks out of the way */ + ASSERT_SUCCESS(h2_fake_peer_send_connection_preface_default_settings(&s_tester.peer)); + ASSERT_SUCCESS(h2_fake_peer_decode_messages_from_testing_channel(&s_tester.peer)); + struct aws_http_message *request = aws_http2_message_new_request(allocator); + ASSERT_NOT_NULL(request); + + struct aws_http_header request_headers_src[] = { + DEFINE_HEADER(":method", "GET"), + DEFINE_HEADER(":scheme", "https"), + DEFINE_HEADER(":path", "/"), + }; + aws_http_message_add_header_array(request, request_headers_src, AWS_ARRAY_SIZE(request_headers_src)); + struct aws_http_make_request_options request_options = { + .self_size = sizeof(request_options), + .request = request, + }; + + struct client_stream_tester stream_tester; + ASSERT_SUCCESS(s_stream_tester_init(&stream_tester, request)); + testing_channel_drain_queued_tasks(&s_tester.testing_channel); + + /* Cancel the request */ + aws_http_stream_cancel(stream_tester.stream, AWS_ERROR_COND_VARIABLE_ERROR_UNKNOWN); + testing_channel_drain_queued_tasks(&s_tester.testing_channel); + + ASSERT_TRUE(aws_http_connection_is_open(s_tester.connection)); + ASSERT_TRUE(stream_tester.complete); + ASSERT_INT_EQUALS(AWS_ERROR_COND_VARIABLE_ERROR_UNKNOWN, stream_tester.on_complete_error_code); + /* validate that stream sent RST_STREAM */ + ASSERT_SUCCESS(h2_fake_peer_decode_messages_from_testing_channel(&s_tester.peer)); + struct h2_decoded_frame *rst_stream_frame = + h2_decode_tester_find_frame(&s_tester.peer.decode, AWS_H2_FRAME_T_RST_STREAM, 0, NULL); + /* But the error code is not the same as user was trying to send */ + ASSERT_UINT_EQUALS(AWS_HTTP2_ERR_CANCEL, rst_stream_frame->error_code); + + /* clean up */ + aws_http_message_release(request); + client_stream_tester_clean_up(&stream_tester); + return s_tester_clean_up(); +} + TEST_CASE(h2_client_stream_keeps_alive_for_cross_thread_task) { ASSERT_SUCCESS(s_tester_init(allocator, ctx));