From 4e74ab1e3702763e0b87bd1752f5a37c2f0400ac Mon Sep 17 00:00:00 2001 From: Dengke Tang Date: Fri, 16 Aug 2024 09:26:44 -0700 Subject: [PATCH] Fix websocket shutdown behavior (#483) The bug was introduced in [PR #474](https://github.com/awslabs/aws-c-http/pull/474/files#diff-ee776c7576cfff50a64158d59a6173ab9a0aa373150574aa9987b4f8726b58e3) - `is_writing_stopped = true` shouldn't be set directly; there's a helper function `s_stop_writing()` that ensures subsequent calls to `aws_websocket_send_frame()` will fail. Let's take a whole new approach to these channel-shutdown-window-deadlock issues: - add `s_stop_reading_and_dont_block_shutdown()` function that sets `is_reading_stopped = true`, but also increments the read window so that channel shutdown won't deadlock. - Most places that were setting `is_reading_stopped = true` now use this helper instead - Revamp how `aws_channel_shutdown()` is called. Lots of channel behavior has changed since [this websocket code was written](https://github.com/awslabs/aws-c-http/pull/48). - If on the channel-thread, just call `aws_channel_shutdown()` - now that [aws_channel_shutdown()](https://github.com/awslabs/aws-c-io/pull/172) is always async, we don't need to defensively schedule a task to call it - If off-thread, use `s_schedule_channel_shutdown_from_offthread()` - now that this is only called from `aws_websocket_close()`, or when the refcount goes to zero, we can assume the user is OK if reading stops, and it can call `s_stop_reading_and_dont_block_shutdown()` on the way to shutting down. 
- Add the test to verify that send after close should fail Co-authored-by: Michael Graeb --- source/websocket.c | 125 ++++++++++++++++----------------- tests/CMakeLists.txt | 1 + tests/test_websocket_handler.c | 32 +++++++++ 3 files changed, 94 insertions(+), 64 deletions(-) diff --git a/source/websocket.c b/source/websocket.c index 8637081c..da5aedbb 100644 --- a/source/websocket.c +++ b/source/websocket.c @@ -43,7 +43,7 @@ struct aws_websocket { aws_websocket_on_incoming_frame_complete_fn *on_incoming_frame_complete; struct aws_channel_task move_synced_data_to_thread_task; - struct aws_channel_task shutdown_channel_task; + struct aws_channel_task shutdown_channel_from_offthread_task; struct aws_channel_task increment_read_window_task; struct aws_channel_task waiting_on_payload_stream_task; struct aws_channel_task close_timeout_task; @@ -85,7 +85,10 @@ struct aws_websocket { /* True when no more frames will be read, due to: * - a CLOSE frame was received * - decoder error - * - channel shutdown in read-dir */ + * - channel shutdown in read-dir + * - user calling aws_websocket_close() + * - user dropping the last refcount + */ bool is_reading_stopped; /* True when no more frames will be written, due to: @@ -124,9 +127,9 @@ struct aws_websocket { /* Error-code returned by aws_websocket_send_frame() when is_writing_stopped is true */ int send_frame_error_code; - /* Use a task to issue a channel shutdown. */ - int shutdown_channel_task_error_code; - bool is_shutdown_channel_task_scheduled; + /* Use a task to issue a channel shutdown from off-thread. 
*/ + int shutdown_channel_from_offthread_task_error_code; + bool is_shutdown_channel_from_offthread_task_scheduled; bool is_move_synced_data_to_thread_task_scheduled; @@ -186,10 +189,13 @@ static bool s_midchannel_send_payload(struct aws_websocket *websocket, struct aw static void s_midchannel_send_complete(struct aws_websocket *websocket, int error_code, void *user_data); static void s_move_synced_data_to_thread_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); static void s_increment_read_window_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); -static void s_shutdown_channel_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); +static void s_shutdown_channel_from_offthread_task( + struct aws_channel_task *task, + void *arg, + enum aws_task_status status); static void s_waiting_on_payload_stream_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); static void s_close_timeout_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); -static void s_schedule_channel_shutdown(struct aws_websocket *websocket, int error_code); +static void s_schedule_channel_shutdown_from_offthread(struct aws_websocket *websocket, int error_code); static void s_shutdown_due_to_write_err(struct aws_websocket *websocket, int error_code); static void s_shutdown_due_to_read_err(struct aws_websocket *websocket, int error_code); static void s_stop_writing(struct aws_websocket *websocket, int send_frame_error_code); @@ -285,7 +291,10 @@ struct aws_websocket *aws_websocket_handler_new(const struct aws_websocket_handl websocket, "websocket_move_synced_data_to_thread"); aws_channel_task_init( - &websocket->shutdown_channel_task, s_shutdown_channel_task, websocket, "websocket_shutdown_channel"); + &websocket->shutdown_channel_from_offthread_task, + s_shutdown_channel_from_offthread_task, + websocket, + "websocket_shutdown_channel"); aws_channel_task_init( 
&websocket->increment_read_window_task, s_increment_read_window_task, @@ -377,7 +386,7 @@ static void s_websocket_on_refcount_zero(void *user_data) { AWS_LS_HTTP_WEBSOCKET, "id=%p: Websocket ref-count is zero, shut down if necessary.", (void *)websocket); /* Channel might already be shut down, but make sure */ - s_schedule_channel_shutdown(websocket, AWS_ERROR_SUCCESS); + s_schedule_channel_shutdown_from_offthread(websocket, AWS_ERROR_SUCCESS); /* Channel won't destroy its slots/handlers until its refcount reaches 0 */ aws_channel_release_hold(websocket->channel_slot->channel); @@ -897,6 +906,21 @@ static void s_complete_frame_list(struct aws_websocket *websocket, struct aws_li aws_linked_list_init(frames); } +/* Set is_reading_stopped = true, all further read data will be ignored. + * But also increment the read window, so that channel shutdown won't deadlock + * due to pending read-data in an upstream handler or the underlying OS socket. */ +static void s_stop_reading_and_dont_block_shutdown(struct aws_websocket *websocket) { + AWS_ASSERT(aws_channel_thread_is_callers_thread(websocket->channel_slot->channel)); + if (websocket->thread_data.is_reading_stopped) { + return; + } + + AWS_LOGF_TRACE(AWS_LS_HTTP_WEBSOCKET, "id=%p: Websocket will ignore any further read data.", (void *)websocket); + websocket->thread_data.is_reading_stopped = true; + + aws_channel_slot_increment_read_window(websocket->channel_slot, SIZE_MAX); +} + static void s_stop_writing(struct aws_websocket *websocket, int send_frame_error_code) { AWS_ASSERT(aws_channel_thread_is_callers_thread(websocket->channel_slot->channel)); AWS_ASSERT(send_frame_error_code != AWS_ERROR_SUCCESS); @@ -947,7 +971,7 @@ static void s_shutdown_due_to_write_err(struct aws_websocket *websocket, int err (void *)websocket, error_code, aws_error_name(error_code)); - s_schedule_channel_shutdown(websocket, error_code); + aws_channel_shutdown(websocket->channel_slot->channel, error_code); } } @@ -961,7 +985,7 @@ static void 
s_shutdown_due_to_read_err(struct aws_websocket *websocket, int erro error_code, aws_error_name(error_code)); - websocket->thread_data.is_reading_stopped = true; + s_stop_reading_and_dont_block_shutdown(websocket); /* If there's a current incoming frame, complete it with the specific error code. */ if (websocket->thread_data.current_incoming_frame) { @@ -969,10 +993,14 @@ static void s_shutdown_due_to_read_err(struct aws_websocket *websocket, int erro } /* Tell channel to shutdown (it's ok to call this redundantly) */ - s_schedule_channel_shutdown(websocket, error_code); + aws_channel_shutdown(websocket->channel_slot->channel, error_code); } -static void s_shutdown_channel_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) { +static void s_shutdown_channel_from_offthread_task( + struct aws_channel_task *task, + void *arg, + enum aws_task_status status) { + (void)task; if (status != AWS_TASK_STATUS_RUN_READY) { @@ -985,39 +1013,39 @@ static void s_shutdown_channel_task(struct aws_channel_task *task, void *arg, en /* BEGIN CRITICAL SECTION */ s_lock_synced_data(websocket); - error_code = websocket->synced_data.shutdown_channel_task_error_code; + error_code = websocket->synced_data.shutdown_channel_from_offthread_task_error_code; s_unlock_synced_data(websocket); /* END CRITICAL SECTION */ - websocket->thread_data.is_reading_stopped = true; - websocket->thread_data.is_writing_stopped = true; + + /* Stop reading, so that shutdown won't be blocked. + * If something off-thread is causing shutdown (aws_websocket_close(), refcount 0, etc), + * the user may never interact with the websocket again. We can't rely on them + * to keep the window open and prevent deadlock during shutdown. */ + s_stop_reading_and_dont_block_shutdown(websocket); aws_channel_shutdown(websocket->channel_slot->channel, error_code); - /* Increase the window size after shutdown starts, to prevent deadlock when data still pending in the upstream - * handler. 
*/ - aws_channel_slot_increment_read_window(websocket->channel_slot, SIZE_MAX); } -/* Tell the channel to shut down. It is safe to call this multiple times. - * The call to aws_channel_shutdown() is delayed so that a user invoking aws_websocket_close doesn't - * have completion callbacks firing before the function call even returns */ -static void s_schedule_channel_shutdown(struct aws_websocket *websocket, int error_code) { +/* Tell the channel to shut down, from off-thread. It is safe to call this multiple times. */ +static void s_schedule_channel_shutdown_from_offthread(struct aws_websocket *websocket, int error_code) { bool schedule_shutdown = false; /* BEGIN CRITICAL SECTION */ s_lock_synced_data(websocket); - if (!websocket->synced_data.is_shutdown_channel_task_scheduled) { + if (!websocket->synced_data.is_shutdown_channel_from_offthread_task_scheduled) { schedule_shutdown = true; - websocket->synced_data.is_shutdown_channel_task_scheduled = true; - websocket->synced_data.shutdown_channel_task_error_code = error_code; + websocket->synced_data.is_shutdown_channel_from_offthread_task_scheduled = true; + websocket->synced_data.shutdown_channel_from_offthread_task_error_code = error_code; } s_unlock_synced_data(websocket); /* END CRITICAL SECTION */ if (schedule_shutdown) { - aws_channel_schedule_task_now(websocket->channel_slot->channel, &websocket->shutdown_channel_task); + aws_channel_schedule_task_now( + websocket->channel_slot->channel, &websocket->shutdown_channel_from_offthread_task); } } @@ -1038,14 +1066,13 @@ void aws_websocket_close(struct aws_websocket *websocket, bool free_scarce_resou return; } - /* TODO: aws_channel_shutdown() should let users specify error_code and "immediate" as separate parameters. - * Currently, any non-zero error_code results in "immediate" shutdown */ + /* TODO: aws_channel_shutdown() should let users specify error_code and "immediate" as separate parameters. 
*/ int error_code = AWS_ERROR_SUCCESS; if (free_scarce_resources_immediately) { error_code = AWS_ERROR_HTTP_CONNECTION_CLOSED; } - s_schedule_channel_shutdown(websocket, error_code); + s_schedule_channel_shutdown_from_offthread(websocket, error_code); } static int s_handler_shutdown( @@ -1255,17 +1282,7 @@ static int s_handler_process_read_message( } if (websocket->thread_data.incoming_message_window_update > 0) { - err = aws_channel_slot_increment_read_window(slot, websocket->thread_data.incoming_message_window_update); - if (err) { - AWS_LOGF_ERROR( - AWS_LS_HTTP_WEBSOCKET, - "id=%p: Failed to increment read window after message processing, error %d (%s). Closing " - "connection.", - (void *)websocket, - aws_last_error(), - aws_error_name(aws_last_error())); - goto error; - } + aws_channel_slot_increment_read_window(slot, websocket->thread_data.incoming_message_window_update); } goto clean_up; @@ -1508,7 +1525,7 @@ static void s_complete_incoming_frame(struct aws_websocket *websocket, int error AWS_LS_HTTP_WEBSOCKET, "id=%p: Close frame received, any further data received will be ignored.", (void *)websocket); - websocket->thread_data.is_reading_stopped = true; + s_stop_reading_and_dont_block_shutdown(websocket); /* TODO: auto-close if there's a channel-handler to the right */ @@ -1598,37 +1615,17 @@ static int s_handler_increment_read_window( } if (increment != 0) { - int err = aws_channel_slot_increment_read_window(slot, increment); - if (err) { - goto error; - } + aws_channel_slot_increment_read_window(slot, increment); } return AWS_OP_SUCCESS; error: - websocket->thread_data.is_reading_stopped = true; /* Shutting down channel because I know that no one ever checks these errors */ s_shutdown_due_to_read_err(websocket, aws_last_error()); return AWS_OP_ERR; } -static void s_increment_read_window_action(struct aws_websocket *websocket, size_t size) { - AWS_ASSERT(aws_channel_thread_is_callers_thread(websocket->channel_slot->channel)); - - int err = 
aws_channel_slot_increment_read_window(websocket->channel_slot, size); - if (err) { - AWS_LOGF_ERROR( - AWS_LS_HTTP_WEBSOCKET, - "id=%p: Failed to increment read window, error %d (%s). Closing websocket.", - (void *)websocket, - aws_last_error(), - aws_error_name(aws_last_error())); - - s_schedule_channel_shutdown(websocket, aws_last_error()); - } -} - static void s_increment_read_window_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) { (void)task; @@ -1651,7 +1648,7 @@ static void s_increment_read_window_task(struct aws_channel_task *task, void *ar AWS_LOGF_TRACE( AWS_LS_HTTP_WEBSOCKET, "id=%p: Running task to increment read window by %zu.", (void *)websocket, size); - s_increment_read_window_action(websocket, size); + aws_channel_slot_increment_read_window(websocket->channel_slot, size); } void aws_websocket_increment_read_window(struct aws_websocket *websocket, size_t size) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index dcd242d9..2cdb4789 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -226,6 +226,7 @@ add_test_case(websocket_handler_window_manual_increment) add_test_case(websocket_handler_window_manual_increment_off_thread) add_test_case(websocket_handler_sends_pong_automatically) add_test_case(websocket_handler_wont_send_pong_after_close_frame) +add_test_case(websocket_handler_send_frame_fails_if_websocket_closed) add_test_case(websocket_midchannel_sanity_check) add_test_case(websocket_midchannel_write_message) add_test_case(websocket_midchannel_write_multiple_messages) diff --git a/tests/test_websocket_handler.c b/tests/test_websocket_handler.c index 98db5130..779b7fc6 100644 --- a/tests/test_websocket_handler.c +++ b/tests/test_websocket_handler.c @@ -1883,6 +1883,38 @@ TEST_CASE(websocket_handler_wont_send_pong_after_close_frame) { return AWS_OP_SUCCESS; } +/* This is a regression test. 
If aws_websocket_close() leads to shutdown, + * then subsequent calls to aws_websocket_send_frame() should fail. */ +TEST_CASE(websocket_handler_send_frame_fails_if_websocket_closed) { + (void)ctx; + (void)ctx; + struct tester tester; + ASSERT_SUCCESS(s_tester_init(&tester, allocator)); + + /* Call aws_websocket_close() and wait for shutdown to complete */ + testing_channel_set_is_on_users_thread(&tester.testing_channel, false); + aws_websocket_close(tester.websocket, false); + testing_channel_set_is_on_users_thread(&tester.testing_channel, true); + + ASSERT_SUCCESS(s_drain_written_messages(&tester)); + ASSERT_TRUE(testing_channel_is_shutdown_completed(&tester.testing_channel)); + + /* aws_websocket_send_frame() should fail */ + struct aws_byte_cursor payload = aws_byte_cursor_from_c_str("bitter butter."); + struct send_tester send = { + .payload = payload, + .def = + { + .opcode = AWS_WEBSOCKET_OPCODE_PING, + .fin = true, + }, + }; + ASSERT_FAILS(s_send_frame(&tester, &send)); + ASSERT_UINT_EQUALS(AWS_ERROR_HTTP_WEBSOCKET_CLOSE_FRAME_SENT, aws_last_error()); + ASSERT_SUCCESS(s_tester_clean_up(&tester)); + return AWS_OP_SUCCESS; +} + TEST_CASE(websocket_midchannel_read_message) { (void)ctx; struct tester tester;